-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][LLVM][MemRef] Lower assume_alignment with operand bundles #117800
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
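@llvm/pr-subscribers-mlir Author: Krzysztof Drewniak (krzysz00) Changes: Now that LLVM allows an operand bundle on assume calls to directly specify alignment assumptions, change the lowering of memref.assume_alignment to use that feature instead of the ptrtoint method. This makes LLVM's job easier and prevents issues when dealing with cases where ptrtoint isn't a desired operation (like those with pointer provenance). Full diff: https://github.com/llvm/llvm-project/pull/117800.diff 3 Files Affected: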
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..c9cf3fe7a014e0 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -192,22 +192,16 @@ struct AssumeAlignmentOpLowering
Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
rewriter);
- // Emit llvm.assume(memref & (alignment - 1) == 0).
- //
- // This relies on LLVM's CSE optimization (potentially after SROA), since
- // after CSE all memref instances should get de-duplicated into the same
- // pointer SSA value.
- MemRefDescriptor memRefDescriptor(memref);
- auto intPtrType =
- getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
- Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
- Value mask =
- createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
- Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
- rewriter.create<LLVM::AssumeOp>(
- loc, rewriter.create<LLVM::ICmpOp>(
- loc, LLVM::ICmpPredicate::eq,
- rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
+ // Emit llvm.assume(true) "align"(memref, alignment).
+ // This is more direct than ptrtoint-based checks, is explicitly supported,
+ // and works with non-integral address spaces.
+ Value trueCond =
+ rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
+ // LLVM docs always store this as an i32, follow the trend.
+ Value alignmentConst = createIndexAttrConstant(
+ rewriter, loc, rewriter.getI32Type(), alignment);
+ rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
+ alignmentConst);
rewriter.eraseOp(op);
return success();
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index ec5ceae57ccb33..f1aaa8bca6ff8f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -675,10 +675,7 @@ func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[ALIGNED_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[INT_TO_PTR:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
-// CHECK: %[[AND:.*]] = llvm.and %[[INT_TO_PTR]], {{.*}} : i64
-// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[AND]], {{.*}} : i64
-// CHECK: llvm.intr.assume %[[CMP]] : i1
+// CHECK: llvm.intr.assume %{{.*}} ["align"(%[[BUFF_ADDR]], %{{.*}} : !llvm.ptr, i32)] : i1
// CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: %[[VAL:.*]] = llvm.load %[[LD_ADDR]] : !llvm.ptr -> f32
// CHECK: return %[[VAL]] : f32
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 48dc9079333d4f..4a8a1b872d31cc 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -155,12 +155,9 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
// CHECK-LABEL: func @assume_alignment(
func.func @assume_alignment(%0 : memref<4x4xf16>) {
// CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
- // CHECK-NEXT: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
- // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm.ptr to i64
- // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
- // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
- // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
+ // CHECK-NEXT: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
+ // CHECK-NEXT: %[[ALIGN:.*]] = llvm.mlir.constant(16 : index) : i32
+ // CHECK-NEXT: llvm.intr.assume %[[TRUE]] ["align"(%[[PTR]], %[[ALIGN]] : !llvm.ptr, i32)] : i1
memref.assume_alignment %0, 16 : memref<4x4xf16>
return
}
@@ -172,12 +169,9 @@ func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset
// CHECK-DAG: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-DAG: %[[OFFSET:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-DAG: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f16
- // CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
- // CHECK-DAG: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
- // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
- // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
- // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
- // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
+ // CHECK-DAG: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
+ // CHECK-DAG: %[[ALIGN:.*]] = llvm.mlir.constant(16 : index) : i32
+ // CHECK-NEXT: llvm.intr.assume %[[TRUE]] ["align"(%[[BUFF_ADDR]], %[[ALIGN]] : !llvm.ptr, i32)] : i1
memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>>
return
}
@@ -410,7 +404,7 @@ func.func @atomic_rmw_with_offset(%I : memref<10xi32, strided<[1], offset: 5>>,
// CHECK-SAME: %[[ARG2:.+]]: index
// CHECK-DAG: %[[MEMREF_STRUCT:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<10xi32, strided<[1], offset: 5>> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-DAG: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i64
-// CHECK: %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(5 : index) : i64
// CHECK: %[[OFFSET_PTR:.+]] = llvm.getelementptr %[[BASE_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
// CHECK: %[[PTR:.+]] = llvm.getelementptr %[[OFFSET_PTR]][%[[INDEX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
@@ -601,7 +595,7 @@ func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
// CHECK-LABEL: func @extract_aligned_pointer_as_index_unranked
func.func @extract_aligned_pointer_as_index_unranked(%m: memref<*xf32>) -> index {
%0 = memref.extract_aligned_pointer_as_index %m: memref<*xf32> -> index
- // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)>
+ // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)>
// CHECK: %[[ALIGNED_FIELD:.*]] = llvm.getelementptr %[[PTR]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.load %[[ALIGNED_FIELD]] : !llvm.ptr -> !llvm.ptr
// CHECK: %[[I64:.*]] = llvm.ptrtoint %[[ALIGNED_PTR]] : !llvm.ptr to i64
|
qcolombet
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice clean up.
One nit below
| loc, rewriter.create<LLVM::ICmpOp>( | ||
| loc, LLVM::ICmpPredicate::eq, | ||
| rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero)); | ||
| // Emit llvm.assume(true) "align"(memref, alignment). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: The parentheses on assume look wrong
Now that LLVM allows an operand bundle on assume calls to directly specify alignment assumptions, change the lowering of memref.assume_alignment to use that feature instead of the ptrtoint method. This makes LLVM's job easier and prevents issues when dealing with cases where ptrtoint isn't a desired operation (like those with pointer provenance).
618466b to
6a870c1
Compare

Now that LLVM allows an operand bundle on assume calls to directly
specify alignment assumptions, change the lowering of
memref.assume_alignment to use that feature instead of the ptrtoint
method.
This makes LLVM's job easier and prevents issues when dealing with
cases where ptrtoint isn't a desired operation (like those with pointer
provenance).