@@ -35,6 +35,9 @@ namespace mlir {
3535
3636using  namespace  mlir ; 
3737
38+ static  constexpr  LLVM::GEPNoWrapFlags kNoWrapFlags  =
39+     LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
40+ 
3841namespace  {
3942
4043static  bool  isStaticStrideOrOffset (int64_t  strideOrOffset) {
@@ -420,8 +423,8 @@ struct AssumeAlignmentOpLowering
420423    auto  loc = op.getLoc ();
421424
422425    auto  srcMemRefType = cast<MemRefType>(op.getMemref ().getType ());
423-     Value ptr = getStridedElementPtr (loc, srcMemRefType, memref,  /* indices= */ {} ,
424-                                      rewriter );
426+     Value ptr = getStridedElementPtr (rewriter,  loc, srcMemRefType, memref,
427+                                      /* indices= */ {} );
425428
426429    //  Emit llvm.assume(true) ["align"(memref, alignment)].
427430    //  This is more direct than ptrtoint-based checks, is explicitly supported,
@@ -644,8 +647,8 @@ struct GenericAtomicRMWOpLowering
644647    //  Compute the loaded value and branch to the loop block.
645648    rewriter.setInsertionPointToEnd (initBlock);
646649    auto  memRefType = cast<MemRefType>(atomicOp.getMemref ().getType ());
647-     auto  dataPtr = getStridedElementPtr (loc, memRefType, adaptor. getMemref (), 
648-                                          adaptor.getIndices (), rewriter );
650+     auto  dataPtr = getStridedElementPtr (
651+         rewriter, loc, memRefType,  adaptor.getMemref (), adaptor. getIndices () );
649652    Value init = rewriter.create <LLVM::LoadOp>(
650653        loc, typeConverter->convertType (memRefType.getElementType ()), dataPtr);
651654    rewriter.create <LLVM::BrOp>(loc, init, loopBlock);
@@ -829,9 +832,12 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
829832                  ConversionPatternRewriter &rewriter) const  override  {
830833    auto  type = loadOp.getMemRefType ();
831834
832-     Value dataPtr =
833-         getStridedElementPtr (loadOp.getLoc (), type, adaptor.getMemref (),
834-                              adaptor.getIndices (), rewriter);
835+     //  Per memref.load spec, the indices must be in-bounds:
836+     //  0 <= idx < dim_size, and additionally all offsets are non-negative,
837+     //  hence inbounds and nuw are used when lowering to llvm.getelementptr.
838+     Value dataPtr = getStridedElementPtr (rewriter, loadOp.getLoc (), type,
839+                                          adaptor.getMemref (),
840+                                          adaptor.getIndices (), kNoWrapFlags );
835841    rewriter.replaceOpWithNewOp <LLVM::LoadOp>(
836842        loadOp, typeConverter->convertType (type.getElementType ()), dataPtr, 0 ,
837843        false , loadOp.getNontemporal ());
@@ -849,8 +855,12 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
849855                  ConversionPatternRewriter &rewriter) const  override  {
850856    auto  type = op.getMemRefType ();
851857
852-     Value dataPtr = getStridedElementPtr (op.getLoc (), type, adaptor.getMemref (),
853-                                          adaptor.getIndices (), rewriter);
858+     //  Per memref.store spec, the indices must be in-bounds:
859+     //  0 <= idx < dim_size, and additionally all offsets are non-negative,
860+     //  hence inbounds and nuw are used when lowering to llvm.getelementptr.
861+     Value dataPtr =
862+         getStridedElementPtr (rewriter, op.getLoc (), type, adaptor.getMemref (),
863+                              adaptor.getIndices (), kNoWrapFlags );
854864    rewriter.replaceOpWithNewOp <LLVM::StoreOp>(op, adaptor.getValue (), dataPtr,
855865                                               0 , false , op.getNontemporal ());
856866    return  success ();
@@ -868,8 +878,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
868878    auto  type = prefetchOp.getMemRefType ();
869879    auto  loc = prefetchOp.getLoc ();
870880
871-     Value dataPtr = getStridedElementPtr (loc, type, adaptor. getMemref (), 
872-                                           adaptor.getIndices (), rewriter );
881+     Value dataPtr = getStridedElementPtr (
882+         rewriter, loc, type,  adaptor.getMemref (), adaptor. getIndices () );
873883
874884    //  Replace with llvm.prefetch.
875885    IntegerAttr isWrite = rewriter.getI32IntegerAttr (prefetchOp.getIsWrite ());
@@ -1809,8 +1819,8 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
18091819    if  (failed (memRefType.getStridesAndOffset (strides, offset)))
18101820      return  failure ();
18111821    auto  dataPtr =
1812-         getStridedElementPtr (atomicOp.getLoc (), memRefType, adaptor. getMemref () ,
1813-                              adaptor.getIndices (), rewriter );
1822+         getStridedElementPtr (rewriter,  atomicOp.getLoc (), memRefType,
1823+                              adaptor.getMemref (), adaptor. getIndices () );
18141824    rewriter.replaceOpWithNewOp <LLVM::AtomicRMWOp>(
18151825        atomicOp, *maybeKind, dataPtr, adaptor.getValue (),
18161826        LLVM::AtomicOrdering::acq_rel);
0 commit comments