@@ -35,6 +35,9 @@ namespace mlir {
3535
3636using namespace mlir ;
3737
38+ static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
39+ LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
40+
3841namespace {
3942
4043static bool isStaticStrideOrOffset (int64_t strideOrOffset) {
@@ -420,8 +423,8 @@ struct AssumeAlignmentOpLowering
420423 auto loc = op.getLoc ();
421424
422425 auto srcMemRefType = cast<MemRefType>(op.getMemref ().getType ());
423- Value ptr = getStridedElementPtr (loc, srcMemRefType, memref, /* indices= */ {} ,
424- rewriter );
426+ Value ptr = getStridedElementPtr (rewriter, loc, srcMemRefType, memref,
427+ /* indices= */ {} );
425428
426429 // Emit llvm.assume(true) ["align"(memref, alignment)].
427430 // This is more direct than ptrtoint-based checks, is explicitly supported,
@@ -643,8 +646,8 @@ struct GenericAtomicRMWOpLowering
643646 // Compute the loaded value and branch to the loop block.
644647 rewriter.setInsertionPointToEnd (initBlock);
645648 auto memRefType = cast<MemRefType>(atomicOp.getMemref ().getType ());
646- auto dataPtr = getStridedElementPtr (loc, memRefType, adaptor. getMemref (),
647- adaptor.getIndices (), rewriter );
649+ auto dataPtr = getStridedElementPtr (
650+ rewriter, loc, memRefType, adaptor.getMemref (), adaptor. getIndices () );
648651 Value init = rewriter.create <LLVM::LoadOp>(
649652 loc, typeConverter->convertType (memRefType.getElementType ()), dataPtr);
650653 rewriter.create <LLVM::BrOp>(loc, init, loopBlock);
@@ -828,9 +831,12 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
828831 ConversionPatternRewriter &rewriter) const override {
829832 auto type = loadOp.getMemRefType ();
830833
831- Value dataPtr =
832- getStridedElementPtr (loadOp.getLoc (), type, adaptor.getMemref (),
833- adaptor.getIndices (), rewriter);
834+ // Per memref.load spec, the indices must be in-bounds:
835+ // 0 <= idx < dim_size, and additionally all offsets are non-negative,
836+ // hence inbounds and nuw are used when lowering to llvm.getelementptr.
837+ Value dataPtr = getStridedElementPtr (rewriter, loadOp.getLoc (), type,
838+ adaptor.getMemref (),
839+ adaptor.getIndices (), kNoWrapFlags );
834840 rewriter.replaceOpWithNewOp <LLVM::LoadOp>(
835841 loadOp, typeConverter->convertType (type.getElementType ()), dataPtr, 0 ,
836842 false , loadOp.getNontemporal ());
@@ -848,8 +854,12 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
848854 ConversionPatternRewriter &rewriter) const override {
849855 auto type = op.getMemRefType ();
850856
851- Value dataPtr = getStridedElementPtr (op.getLoc (), type, adaptor.getMemref (),
852- adaptor.getIndices (), rewriter);
857+ // Per memref.store spec, the indices must be in-bounds:
858+ // 0 <= idx < dim_size, and additionally all offsets are non-negative,
859+ // hence inbounds and nuw are used when lowering to llvm.getelementptr.
860+ Value dataPtr =
861+ getStridedElementPtr (rewriter, op.getLoc (), type, adaptor.getMemref (),
862+ adaptor.getIndices (), kNoWrapFlags );
853863 rewriter.replaceOpWithNewOp <LLVM::StoreOp>(op, adaptor.getValue (), dataPtr,
854864 0 , false , op.getNontemporal ());
855865 return success ();
@@ -867,8 +877,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
867877 auto type = prefetchOp.getMemRefType ();
868878 auto loc = prefetchOp.getLoc ();
869879
870- Value dataPtr = getStridedElementPtr (loc, type, adaptor. getMemref (),
871- adaptor.getIndices (), rewriter );
880+ Value dataPtr = getStridedElementPtr (
881+ rewriter, loc, type, adaptor.getMemref (), adaptor. getIndices () );
872882
873883 // Replace with llvm.prefetch.
874884 IntegerAttr isWrite = rewriter.getI32IntegerAttr (prefetchOp.getIsWrite ());
@@ -1808,8 +1818,8 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
18081818 if (failed (memRefType.getStridesAndOffset (strides, offset)))
18091819 return failure ();
18101820 auto dataPtr =
1811- getStridedElementPtr (atomicOp.getLoc (), memRefType, adaptor. getMemref () ,
1812- adaptor.getIndices (), rewriter );
1821+ getStridedElementPtr (rewriter, atomicOp.getLoc (), memRefType,
1822+ adaptor.getMemref (), adaptor. getIndices () );
18131823 rewriter.replaceOpWithNewOp <LLVM::AtomicRMWOp>(
18141824 atomicOp, *maybeKind, dataPtr, adaptor.getValue (),
18151825 LLVM::AtomicOrdering::acq_rel);
0 commit comments