diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index 7a58e4fc2f984..66d0fc624e8f1 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -83,9 +83,10 @@ class ConvertToLLVMPattern : public ConversionPattern { // This is a strided getElementPtr variant that linearizes subscripts as: // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. - Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, - ValueRange indices, - ConversionPatternRewriter &rewriter) const; + Value getStridedElementPtr( + ConversionPatternRewriter &rewriter, Location loc, MemRefType type, + Value memRefDesc, ValueRange indices, + LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none) const; /// Returns if the given memref type is convertible to LLVM and has an /// identity layout map. diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index f34b5b46cab50..03637eedc495a 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1187,7 +1187,12 @@ def LoadOp : MemRef_Op<"load", The `load` op reads an element from a memref at the specified indices. The number of indices must match the rank of the memref. The indices must - be in-bounds: `0 <= idx < dim_size` + be in-bounds: `0 <= idx < dim_size`. + + Lowerings of `memref.load` may emit attributes, e.g. `inbounds` + `nuw` + when converting to LLVM's `llvm.getelementptr`, that would cause undefined + behavior if indices are out of bounds or if computing the offset in the + memref would cause signed overflow of the `index` type. The single result of `memref.load` is a value with the same type as the element type of the memref. @@ -1881,7 +1886,12 @@ def MemRef_StoreOp : MemRef_Op<"store", The `store` op stores an element into a memref at the specified indices. The number of indices must match the rank of the memref. The indices must - be in-bounds: `0 <= idx < dim_size` + be in-bounds: `0 <= idx < dim_size`. + + Lowerings of `memref.store` may emit attributes, e.g. `inbounds` + `nuw` + when converting to LLVM's `llvm.getelementptr`, that would cause undefined + behavior if indices are out of bounds or if computing the offset in the + memref would cause signed overflow of the `index` type. A set `nontemporal` attribute indicates that this store is not expected to be reused in the cache.
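As a sketch of what the new documentation permits (illustrative, not part of the patch), consider a statically shaped load:

    func.func @example(%m: memref<10x42xf32>, %i: index, %j: index) -> f32 {
      %v = memref.load %m[%i, %j] : memref<10x42xf32>
      return %v : f32
    }

This now lowers (abridged, SSA names invented) to a linearized offset %i * 42 + %j feeding a GEP that carries both flags:

    %off  = llvm.add %offI, %jj : i64
    %addr = llvm.getelementptr inbounds|nuw %ptr[%off] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %v    = llvm.load %addr : !llvm.ptr -> f32

With inbounds and nuw on the GEP, an out-of-bounds index or a signed overflow while computing the offset is immediate undefined behavior at the LLVM level, which is exactly the latitude the new wording reserves.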
For details, refer to the diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 6e596485cbb58..ff462033462b2 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1118,10 +1118,12 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern { if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4) return op.emitOpError("chipset unsupported element size"); - Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(), - (adaptor.getSrcIndices()), rewriter); - Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(), - (adaptor.getDstIndices()), rewriter); + Value srcPtr = + getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(), + (adaptor.getSrcIndices())); + Value dstPtr = + getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(), + (adaptor.getDstIndices())); rewriter.replaceOpWithNewOp( op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth), diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 417555792b44f..0c3f942b5cbd9 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -299,9 +299,9 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { auto sliceIndexI64 = rewriter.create( loc, rewriter.getI64Type(), sliceIndex); return getStridedElementPtr( - loc, llvm::cast(tileMemory.getType()), - descriptor.getResult(0), {sliceIndexI64, zero}, - static_cast(rewriter)); + static_cast(rewriter), loc, + llvm::cast(tileMemory.getType()), descriptor.getResult(0), + {sliceIndexI64, zero}); } /// Emits an in-place swap of a slice of a tile in ZA and a slice of a @@ -507,9 +507,9 @@ struct LoadTileSliceConversion if (!tileId) return failure(); - Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(), - adaptor.getBase(), - adaptor.getIndices(), rewriter); + Value ptr = this->getStridedElementPtr( + rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(), + adaptor.getIndices()); auto tileSlice = loadTileSliceOp.getTileSliceIndex(); @@ -554,8 +554,8 @@ struct StoreTileSliceConversion // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice. Value ptr = this->getStridedElementPtr( - loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(), - adaptor.getIndices(), rewriter); + rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(), + adaptor.getIndices()); auto tileSlice = storeTileSliceOp.getTileSliceIndex(); diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index 4bd94bcebf290..45fd933d58857 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -122,8 +122,9 @@ struct WmmaLoadOpToNVVMLowering // Create nvvm.mma_load op according to the operand types. 
Value dataPtr = getStridedElementPtr( - loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()), - adaptor.getSrcMemref(), adaptor.getIndices(), rewriter); + rewriter, loc, + cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()), + adaptor.getSrcMemref(), adaptor.getIndices()); Value leadingDim = rewriter.create<LLVM::ConstantOp>( loc, rewriter.getI32Type(), @@ -177,9 +178,9 @@ struct WmmaStoreOpToNVVMLowering } Value dataPtr = getStridedElementPtr( - loc, + rewriter, loc, cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()), - adaptor.getDstMemref(), adaptor.getIndices(), rewriter); + adaptor.getDstMemref(), adaptor.getIndices()); Value leadingDim = rewriter.create<LLVM::ConstantOp>( loc, rewriter.getI32Type(), subgroupMmaStoreMatrixOp.getLeadDimensionAttr()); diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 0505214de2015..6942a64048722 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -59,8 +59,9 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, } Value ConvertToLLVMPattern::getStridedElementPtr( - Location loc, MemRefType type, Value memRefDesc, ValueRange indices, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter, Location loc, MemRefType type, + Value memRefDesc, ValueRange indices, + LLVM::GEPNoWrapFlags noWrapFlags) const { auto [strides, offset] = type.getStridesAndOffset(); @@ -91,7 +92,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr( return index ? rewriter.create<LLVM::GEPOp>( loc, elementPtrType, getTypeConverter()->convertType(type.getElementType()), - base, index) + base, index, noWrapFlags) : base; } diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 158de6dea58c9..3250ea9df3d38 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -35,6 +35,9 @@ namespace mlir { using namespace mlir; +static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags = + LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw; + namespace { static bool isStaticStrideOrOffset(int64_t strideOrOffset) { @@ -420,8 +423,8 @@ struct AssumeAlignmentOpLowering auto loc = op.getLoc(); auto srcMemRefType = cast<MemRefType>(op.getMemref().getType()); - Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{}, - rewriter); + Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref, + /*indices=*/{}); // Emit llvm.assume(true) ["align"(memref, alignment)]. // This is more direct than ptrtoint-based checks, is explicitly supported, @@ -644,8 +647,8 @@ struct GenericAtomicRMWOpLowering // Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock); auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType()); - auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(), - adaptor.getIndices(), rewriter); + auto dataPtr = getStridedElementPtr( + rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices()); Value init = rewriter.create<LLVM::LoadOp>( loc, typeConverter->convertType(memRefType.getElementType()), dataPtr); rewriter.create<LLVM::BrOp>(loc, init, loopBlock); @@ -829,9 +832,12 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> { ConversionPatternRewriter &rewriter) const override { auto type = loadOp.getMemRefType(); - Value dataPtr = - getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(), - adaptor.getIndices(), rewriter); + // Per memref.load spec, the indices must be in-bounds: + // 0 <= idx < dim_size, and additionally all offsets are non-negative, + // hence inbounds and nuw are used when lowering to llvm.getelementptr. + Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type, + adaptor.getMemref(), + adaptor.getIndices(), kNoWrapFlags); rewriter.replaceOpWithNewOp<LLVM::LoadOp>( loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0, false, loadOp.getNontemporal()); @@ -849,8 +855,12 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> { ConversionPatternRewriter &rewriter) const override { auto type = op.getMemRefType(); - Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(), - adaptor.getIndices(), rewriter); + // Per memref.store spec, the indices must be in-bounds: + // 0 <= idx < dim_size, and additionally all offsets are non-negative, + // hence inbounds and nuw are used when lowering to llvm.getelementptr. + Value dataPtr = + getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(), + adaptor.getIndices(), kNoWrapFlags); rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr, 0, false, op.getNontemporal()); return success(); @@ -868,8 +878,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> { auto type = prefetchOp.getMemRefType(); auto loc = prefetchOp.getLoc(); - Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(), - adaptor.getIndices(), rewriter); + Value dataPtr = getStridedElementPtr( + rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices()); // Replace with llvm.prefetch.
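Downstream call sites only need the argument reorder, since the new trailing parameter defaults to LLVM::GEPNoWrapFlags::none; opting into the stricter semantics is one extra argument. A hedged sketch for a hypothetical downstream pattern (MyLoadOp and its adaptor are illustrative names, not from this patch):

    // Indices are known in-bounds for this hypothetical op, so its GEP may
    // carry the same inbounds|nuw flags that the memref.load lowering uses.
    Value dataPtr = getStridedElementPtr(
        rewriter, op.getLoc(), op.getMemRefType(), adaptor.getMemref(),
        adaptor.getIndices(),
        LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw);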
IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()); @@ -1809,8 +1819,8 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering { if (failed(memRefType.getStridesAndOffset(strides, offset))) return failure(); auto dataPtr = - getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(), - adaptor.getIndices(), rewriter); + getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType, + adaptor.getMemref(), adaptor.getIndices()); rewriter.replaceOpWithNewOp( atomicOp, *maybeKind, dataPtr, adaptor.getValue(), LLVM::AtomicOrdering::acq_rel); diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 69fa62c8196e4..eb3558d2460e4 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -283,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern { auto srcMemrefType = cast(op.getSrcMemref().getType()); Value srcPtr = - getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(), - adaptor.getIndices(), rewriter); + getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType, + adaptor.getSrcMemref(), adaptor.getIndices()); Value ldMatrixResult = b.create( ldMatrixResultType, srcPtr, /*num=*/op.getNumTiles(), @@ -661,8 +661,8 @@ struct NVGPUAsyncCopyLowering Location loc = op.getLoc(); auto dstMemrefType = cast(op.getDst().getType()); Value dstPtr = - getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(), - adaptor.getDstIndices(), rewriter); + getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType, + adaptor.getDst(), adaptor.getDstIndices()); FailureOr dstAddressSpace = getTypeConverter()->getMemRefAddressSpace(dstMemrefType); if (failed(dstAddressSpace)) @@ -676,8 +676,9 @@ struct NVGPUAsyncCopyLowering return rewriter.notifyMatchFailure( loc, "source memref address space not convertible to integer"); - Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(), - adaptor.getSrcIndices(), rewriter); + Value scrPtr = + getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(), + adaptor.getSrcIndices()); // Intrinsics takes a global pointer so we need an address space cast. 
auto srcPointerGlobalType = LLVM::LLVMPointerType::get( op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace); @@ -814,7 +815,7 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern { MemRefType mbarrierMemrefType = nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType); return ConvertToLLVMPattern::getStridedElementPtr( - b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter); + rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}); } }; @@ -995,8 +996,8 @@ struct NVGPUTmaAsyncLoadOpLowering ConversionPatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op->getLoc(), rewriter); auto srcMemrefType = cast(op.getDst().getType()); - Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType, - adaptor.getDst(), {}, rewriter); + Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType, + adaptor.getDst(), {}); Value barrier = getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), adaptor.getMbarId(), rewriter); @@ -1021,8 +1022,8 @@ struct NVGPUTmaAsyncStoreOpLowering ConversionPatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op->getLoc(), rewriter); auto srcMemrefType = cast(op.getSrc().getType()); - Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType, - adaptor.getSrc(), {}, rewriter); + Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType, + adaptor.getSrc(), {}); SmallVector coords = adaptor.getCoordinates(); for (auto [index, value] : llvm::enumerate(coords)) { coords[index] = truncToI32(b, value); @@ -1083,8 +1084,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering Value leadDim = makeConst(leadDimVal); Value baseAddr = getStridedElementPtr( - op->getLoc(), cast(op.getTensor().getType()), - adaptor.getTensor(), {}, rewriter); + rewriter, op->getLoc(), cast(op.getTensor().getType()), + adaptor.getTensor(), {}); Value basePtr = b.create(ti64, baseAddr); // Just use 14 bits for base address Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 400003d37bf20..f725993635672 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -289,8 +289,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern { // Resolve address. auto vtype = cast( this->typeConverter->convertType(loadOrStoreOp.getVectorType())); - Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), - adaptor.getIndices(), rewriter); + Value dataPtr = this->getStridedElementPtr( + rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices()); replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align, rewriter); return success(); @@ -337,8 +337,8 @@ class VectorGatherOpConversion return rewriter.notifyMatchFailure(gather, "could not resolve alignment"); // Resolve address. - Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), - adaptor.getIndices(), rewriter); + Value ptr = getStridedElementPtr(rewriter, loc, memRefType, + adaptor.getBase(), adaptor.getIndices()); Value base = adaptor.getBase(); Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType, @@ -393,8 +393,8 @@ class VectorScatterOpConversion "could not resolve alignment"); // Resolve address. 
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), - adaptor.getIndices(), rewriter); + Value ptr = getStridedElementPtr(rewriter, loc, memRefType, + adaptor.getBase(), adaptor.getIndices()); Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(), ptr, adaptor.getIndexVec(), vType); @@ -428,8 +428,8 @@ class VectorExpandLoadOpConversion // Resolve address. auto vtype = typeConverter->convertType(expand.getVectorType()); - Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), - adaptor.getIndices(), rewriter); + Value ptr = getStridedElementPtr(rewriter, loc, memRefType, + adaptor.getBase(), adaptor.getIndices()); rewriter.replaceOpWithNewOp( expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); @@ -450,8 +450,8 @@ class VectorCompressStoreOpConversion MemRefType memRefType = compress.getMemRefType(); // Resolve address. - Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), - adaptor.getIndices(), rewriter); + Value ptr = getStridedElementPtr(rewriter, loc, memRefType, + adaptor.getBase(), adaptor.getIndices()); rewriter.replaceOpWithNewOp( compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp index 4cb777b03b196..2168409184549 100644 --- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp @@ -105,8 +105,8 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern { if (failed(stride)) return failure(); // Replace operation with intrinsic. - Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(), - adaptor.getIndices(), rewriter); + Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType, + adaptor.getBase(), adaptor.getIndices()); Type resType = typeConverter->convertType(tType); rewriter.replaceOpWithNewOp( op, resType, tsz.first, tsz.second, ptr, stride.value()); @@ -131,8 +131,8 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern { if (failed(stride)) return failure(); // Replace operation with intrinsic. 
- Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(), - adaptor.getIndices(), rewriter); + Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType, + adaptor.getBase(), adaptor.getIndices()); rewriter.replaceOpWithNewOp( op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal()); return success(); diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir index 058b69b8e3596..3b52d8fd76464 100644 --- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir +++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir @@ -266,7 +266,7 @@ func.func @bare_ptr_calling_conv(%arg0: memref<4x3xf32>, %arg1 : index, %arg2 : // CHECK: %[[INSERT_STRIDE1:.*]] = llvm.insertvalue %[[C1]], %[[INSERT_DIM1]][4, 1] // CHECK: %[[ALIGNEDPTR:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1] - // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR]] + // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr inbounds|nuw %[[ALIGNEDPTR]] // CHECK: llvm.store %{{.*}}, %[[STOREPTR]] memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32> @@ -295,12 +295,12 @@ func.func @bare_ptr_calling_conv_multiresult(%arg0: memref<4x3xf32>, %arg1 : ind // CHECK: %[[INSERT_STRIDE1:.*]] = llvm.insertvalue %[[C1]], %[[INSERT_DIM1]][4, 1] // CHECK: %[[ALIGNEDPTR:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1] - // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR]] + // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr inbounds|nuw %[[ALIGNEDPTR]] // CHECK: llvm.store %{{.*}}, %[[STOREPTR]] memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32> // CHECK: %[[ALIGNEDPTR0:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1] - // CHECK: %[[LOADPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR0]] + // CHECK: %[[LOADPTR:.*]] = llvm.getelementptr inbounds|nuw %[[ALIGNEDPTR0]] // CHECK: %[[RETURN0:.*]] = llvm.load %[[LOADPTR]] %0 = memref.load %arg0[%arg1, %arg2] : memref<4x3xf32> diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir index be3ddc20c17b7..9ca8bcd1491bc 100644 --- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir @@ -177,7 +177,7 @@ func.func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) { // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64 // CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr -> f32 %0 = memref.load %mixed[%i, %j] : memref<42x?xf32> return @@ -194,7 +194,7 @@ func.func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) { // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64 // CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr -> f32 %0 = memref.load %dynamic[%i, %j] : memref<?x?xf32> return @@ -232,7 +232,7 @@ func.func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %va // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64 // CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr memref.store %val, %dynamic[%i, %j] : memref<?x?xf32> return @@ -249,7 +249,7 @@ func.func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64 // CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr memref.store %val, %mixed[%i, %j] : memref<42x?xf32> return diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir index 0a92c7cf7b216..b03ac2c20112b 100644 --- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir @@ -140,7 +140,7 @@ func.func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) { // CHECK: %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64 // CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64 // CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] : i64 -// CHECK: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: llvm.load %[[addr]] : !llvm.ptr -> f32 %0 = memref.load %static[%i, %j] : memref<10x42xf32> return @@ -168,7 +168,7 @@ func.func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %va // CHECK: %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64 // CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64 // CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] : i64 -// CHECK: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr memref.store %val, %static[%i, %j] : memref<10x42xf32> @@ -307,7 +307,7 @@ func.func @memref.reshape.dynamic.dim(%arg: memref, %shape: memref<4x // CHECK: %[[three_hundred_and_eighty_four:.*]] = llvm.mlir.constant(384 : index) : i64 // CHECK: %[[one1:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK: %[[shape_ptr0:.*]] = llvm.extractvalue %[[shape_cast]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: %[[shape_gep0:.*]] = llvm.getelementptr %[[shape_ptr0]][%[[one1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64
+ // CHECK: %[[shape_gep0:.*]] = llvm.getelementptr inbounds|nuw %[[shape_ptr0]][%[[one1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64 // CHECK: %[[shape_load0:.*]] = llvm.load %[[shape_gep0]] : !llvm.ptr -> i64 // CHECK: %[[insert7:.*]] = llvm.insertvalue %[[shape_load0]], %[[insert6]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> // CHECK: %[[insert8:.*]] = llvm.insertvalue %[[three_hundred_and_eighty_four]], %[[insert7]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> @@ -315,7 +315,7 @@ func.func @memref.reshape.dynamic.dim(%arg: memref, %shape: memref<4x // CHECK: %[[mul:.*]] = llvm.mul %19, %23 : i64 // CHECK: %[[zero1:.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK: %[[shape_ptr1:.*]] = llvm.extractvalue %[[shape_cast]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: %[[shape_gep1:.*]] = llvm.getelementptr %[[shape_ptr1]][%[[zero1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64 + // CHECK: %[[shape_gep1:.*]] = llvm.getelementptr inbounds|nuw %[[shape_ptr1]][%[[zero1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64 // CHECK: %[[shape_load1:.*]] = llvm.load %[[shape_gep1]] : !llvm.ptr -> i64 // CHECK: %[[insert9:.*]] = llvm.insertvalue %[[shape_load1]], %[[insert8]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> // CHECK: %[[insert10:.*]] = llvm.insertvalue %[[mul]], %[[insert9]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> @@ -347,7 +347,7 @@ func.func @memref.reshape_index(%arg0: memref, %shape: memref<1xindex>) // CHECK: %[[zero1:.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK: %[[shape_ptr0:.*]] = llvm.extractvalue %[[shape_cast:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: %[[shape_gep0:.*]] = llvm.getelementptr %[[shape_ptr0:.*]][%[[zero1:.*]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64 + // CHECK: %[[shape_gep0:.*]] = llvm.getelementptr inbounds|nuw %[[shape_ptr0:.*]][%[[zero1:.*]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64 // CHECK: %[[shape_load0:.*]] = llvm.load %[[shape_gep0:.*]] : !llvm.ptr -> i64 // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[shape_load0:.*]], %[[insert2:.*]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one0:.*]], %[[insert3:.*]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir index fe91d26d5a251..411abe6ac78d9 100644 --- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir @@ -676,7 +676,7 @@ func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[ALIGNED_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: llvm.intr.assume %{{.*}} ["align"(%[[BUFF_ADDR]], %{{.*}} : !llvm.ptr, i64)] : i1 -// CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr inbounds|nuw %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: %[[VAL:.*]] = llvm.load %[[LD_ADDR]] : !llvm.ptr -> f32 // CHECK: return %[[VAL]] : f32 func.func @load_and_assume(
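To see the new flags locally, a hedged reproduction (assuming a built mlir-opt and the usual finalize-memref-to-llvm pipeline; the RUN lines of the updated tests are authoritative for the exact invocation):

    $ mlir-opt --finalize-memref-to-llvm static_load.mlir

which, for a load like @static_load above, should print, abridged and with generic SSA names:

    %7 = llvm.getelementptr inbounds|nuw %5[%6] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %8 = llvm.load %7 : !llvm.ptr -> f32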