diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 8da850678878d..48fbcbcdbbde9 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -73,6 +73,15 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
   Value base =
       memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
 
+  LLVM::IntegerOverflowFlags intOverflowFlags =
+      LLVM::IntegerOverflowFlags::none;
+  if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
+    intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
+  }
+  if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
+    intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
+  }
+
   Type indexType = getIndexType();
   Value index;
   for (int i = 0, e = indices.size(); i < e; ++i) {
@@ -82,10 +91,12 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
           ShapedType::isDynamic(strides[i])
               ? memRefDescriptor.stride(rewriter, loc, i)
               : createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
-      increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
+      increment = rewriter.create<LLVM::MulOp>(loc, increment, stride,
+                                               intOverflowFlags);
     }
-    index =
-        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+    index = index ? rewriter.create<LLVM::AddOp>(loc, index, increment,
+                                                 intOverflowFlags)
+                  : increment;
   }
 
   Type elementPtrType = memRefDescriptor.getElementPtrType();
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
index 9ca8bcd1491bc..543fdf5c26f5e 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -175,8 +175,8 @@ func.func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
 // CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
 // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr -> f32
   %0 = memref.load %mixed[%i, %j] : memref<42x?xf32>
@@ -192,8 +192,8 @@ func.func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
 // CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
 // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr -> f32
   %0 = memref.load %dynamic[%i, %j] : memref<?x?xf32>
@@ -230,8 +230,8 @@ func.func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %va
 // CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
 // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr
   memref.store %val, %dynamic[%i, %j] : memref<?x?xf32>
@@ -247,8 +247,8 @@ func.func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val :
 // CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
 // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr
   memref.store %val, %mixed[%i, %j] : memref<42x?xf32>
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index b03ac2c20112b..040a27e160557 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -138,8 +138,8 @@ func.func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
 // CHECK-DAG: %[[JJ:.*]] = builtin.unrealized_conversion_cast %[[J]]
 // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64
-// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64
-// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] : i64
+// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] overflow<nsw, nuw> : i64
 // CHECK: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK: llvm.load %[[addr]] : !llvm.ptr -> f32
   %0 = memref.load %static[%i, %j] : memref<10x42xf32>
@@ -166,8 +166,8 @@ func.func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %va
 // CHECK-DAG: %[[JJ:.*]] = builtin.unrealized_conversion_cast %[[J]]
 // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64
-// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64
-// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] : i64
+// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] overflow<nsw, nuw> : i64
+// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] overflow<nsw, nuw> : i64
 // CHECK: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr