@@ -35,6 +35,9 @@ namespace mlir {
3535
3636using  namespace  mlir ; 
3737
38+ static  constexpr  LLVM::GEPNoWrapFlags kNoWrapFlags  =
39+     LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
40+ 
3841namespace  {
3942
4043static  bool  isStaticStrideOrOffset (int64_t  strideOrOffset) {
@@ -420,8 +423,8 @@ struct AssumeAlignmentOpLowering
420423    auto  loc = op.getLoc ();
421424
422425    auto  srcMemRefType = cast<MemRefType>(op.getMemref ().getType ());
423-     Value ptr = getStridedElementPtr (loc, srcMemRefType, memref,  /* indices= */ {} ,
424-                                      rewriter );
426+     Value ptr = getStridedElementPtr (rewriter,  loc, srcMemRefType, memref,
427+                                      /* indices= */ {} );
425428
426429    //  Emit llvm.assume(true) ["align"(memref, alignment)].
427430    //  This is more direct than ptrtoint-based checks, is explicitly supported,
@@ -643,8 +646,8 @@ struct GenericAtomicRMWOpLowering
643646    //  Compute the loaded value and branch to the loop block.
644647    rewriter.setInsertionPointToEnd (initBlock);
645648    auto  memRefType = cast<MemRefType>(atomicOp.getMemref ().getType ());
646-     auto  dataPtr = getStridedElementPtr (loc, memRefType, adaptor. getMemref (), 
647-                                          adaptor.getIndices (), rewriter );
649+     auto  dataPtr = getStridedElementPtr (
650+         rewriter, loc, memRefType,  adaptor.getMemref (), adaptor. getIndices () );
648651    Value init = rewriter.create <LLVM::LoadOp>(
649652        loc, typeConverter->convertType (memRefType.getElementType ()), dataPtr);
650653    rewriter.create <LLVM::BrOp>(loc, init, loopBlock);
@@ -828,9 +831,12 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
828831                  ConversionPatternRewriter &rewriter) const  override  {
829832    auto  type = loadOp.getMemRefType ();
830833
831-     Value dataPtr =
832-         getStridedElementPtr (loadOp.getLoc (), type, adaptor.getMemref (),
833-                              adaptor.getIndices (), rewriter);
834+     //  Per memref.load spec, the indices must be in-bounds:
835+     //  0 <= idx < dim_size, and additionally all offsets are non-negative,
836+     //  hence inbounds and nuw are used when lowering to llvm.getelementptr.
837+     Value dataPtr = getStridedElementPtr (rewriter, loadOp.getLoc (), type,
838+                                          adaptor.getMemref (),
839+                                          adaptor.getIndices (), kNoWrapFlags );
834840    rewriter.replaceOpWithNewOp <LLVM::LoadOp>(
835841        loadOp, typeConverter->convertType (type.getElementType ()), dataPtr, 0 ,
836842        false , loadOp.getNontemporal ());
@@ -848,8 +854,12 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
848854                  ConversionPatternRewriter &rewriter) const  override  {
849855    auto  type = op.getMemRefType ();
850856
851-     Value dataPtr = getStridedElementPtr (op.getLoc (), type, adaptor.getMemref (),
852-                                          adaptor.getIndices (), rewriter);
857+     //  Per memref.store spec, the indices must be in-bounds:
858+     //  0 <= idx < dim_size, and additionally all offsets are non-negative,
859+     //  hence inbounds and nuw are used when lowering to llvm.getelementptr.
860+     Value dataPtr =
861+         getStridedElementPtr (rewriter, op.getLoc (), type, adaptor.getMemref (),
862+                              adaptor.getIndices (), kNoWrapFlags );
853863    rewriter.replaceOpWithNewOp <LLVM::StoreOp>(op, adaptor.getValue (), dataPtr,
854864                                               0 , false , op.getNontemporal ());
855865    return  success ();
@@ -867,8 +877,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
867877    auto  type = prefetchOp.getMemRefType ();
868878    auto  loc = prefetchOp.getLoc ();
869879
870-     Value dataPtr = getStridedElementPtr (loc, type, adaptor. getMemref (), 
871-                                           adaptor.getIndices (), rewriter );
880+     Value dataPtr = getStridedElementPtr (
881+         rewriter, loc, type,  adaptor.getMemref (), adaptor. getIndices () );
872882
873883    //  Replace with llvm.prefetch.
874884    IntegerAttr isWrite = rewriter.getI32IntegerAttr (prefetchOp.getIsWrite ());
@@ -1808,8 +1818,8 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
18081818    if  (failed (memRefType.getStridesAndOffset (strides, offset)))
18091819      return  failure ();
18101820    auto  dataPtr =
1811-         getStridedElementPtr (atomicOp.getLoc (), memRefType, adaptor. getMemref () ,
1812-                              adaptor.getIndices (), rewriter );
1821+         getStridedElementPtr (rewriter,  atomicOp.getLoc (), memRefType,
1822+                              adaptor.getMemref (), adaptor. getIndices () );
18131823    rewriter.replaceOpWithNewOp <LLVM::AtomicRMWOp>(
18141824        atomicOp, *maybeKind, dataPtr, adaptor.getValue (),
18151825        LLVM::AtomicOrdering::acq_rel);
0 commit comments