Skip to content

Conversation

@krzysz00
Copy link
Contributor

Now that LLVM allows a operand bundle on assume calls to directly
specify alignment assumptions, change the lowering of
memref.assume_alignment to use that feature instead of the ptrtoint
method.

This makes LLVM's job easier and prevents issues when dealing with
cases where ptrtoint isn't a desired operation (like those with poiner
provenance)

Copy link
Contributor Author

This stack of pull requests is managed by Graphite. Learn more about stacking.

@krzysz00 krzysz00 marked this pull request as ready for review November 26, 2024 21:41
@llvmbot llvmbot added the mlir label Nov 26, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2024

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

Now that LLVM allows a operand bundle on assume calls to directly
specify alignment assumptions, change the lowering of
memref.assume_alignment to use that feature instead of the ptrtoint
method.

This makes LLVM's job easier and prevents issues when dealing with
cases where ptrtoint isn't a desired operation (like those with poiner
provenance)


Full diff: https://github.com/llvm/llvm-project/pull/117800.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+10-16)
  • (modified) mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir (+1-4)
  • (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+8-14)
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..c9cf3fe7a014e0 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -192,22 +192,16 @@ struct AssumeAlignmentOpLowering
     Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
                                      rewriter);
 
-    // Emit llvm.assume(memref & (alignment - 1) == 0).
-    //
-    // This relies on LLVM's CSE optimization (potentially after SROA), since
-    // after CSE all memref instances should get de-duplicated into the same
-    // pointer SSA value.
-    MemRefDescriptor memRefDescriptor(memref);
-    auto intPtrType =
-        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
-    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
-    Value mask =
-        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
-    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
-    rewriter.create<LLVM::AssumeOp>(
-        loc, rewriter.create<LLVM::ICmpOp>(
-                 loc, LLVM::ICmpPredicate::eq,
-                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
+    // Emit llvm.assume(true) "align"(memref, alignment).
+    // This is more direct than ptrtoint-based checks, is explicitly supported,
+    // and works with non-integral address spaces.
+    Value trueCond =
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
+    // LLVM docs always store this as an i32, follow the trend.
+    Value alignmentConst = createIndexAttrConstant(
+        rewriter, loc, rewriter.getI32Type(), alignment);
+    rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
+                                    alignmentConst);
 
     rewriter.eraseOp(op);
     return success();
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index ec5ceae57ccb33..f1aaa8bca6ff8f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -675,10 +675,7 @@ func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf
 // CHECK: %[[ALIGNED_PTR:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[ALIGNED_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[INT_TO_PTR:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
-// CHECK: %[[AND:.*]] = llvm.and %[[INT_TO_PTR]], {{.*}}  : i64
-// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[AND]], {{.*}} : i64
-// CHECK: llvm.intr.assume %[[CMP]] : i1
+// CHECK: llvm.intr.assume %{{.*}} ["align"(%[[BUFF_ADDR]], %{{.*}} : !llvm.ptr, i32)] : i1
 // CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
 // CHECK: %[[VAL:.*]] = llvm.load %[[LD_ADDR]] : !llvm.ptr -> f32
 // CHECK: return %[[VAL]] : f32
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 48dc9079333d4f..4a8a1b872d31cc 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -155,12 +155,9 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
 // CHECK-LABEL: func @assume_alignment(
 func.func @assume_alignment(%0 : memref<4x4xf16>) {
   // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-  // CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
-  // CHECK-NEXT: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
-  // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm.ptr to i64
-  // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
-  // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
-  // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
+  // CHECK-NEXT: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
+  // CHECK-NEXT: %[[ALIGN:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK-NEXT: llvm.intr.assume %[[TRUE]] ["align"(%[[PTR]], %[[ALIGN]] : !llvm.ptr, i32)] : i1
   memref.assume_alignment %0, 16 : memref<4x4xf16>
   return
 }
@@ -172,12 +169,9 @@ func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset
   // CHECK-DAG: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK-DAG: %[[OFFSET:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK-DAG: %[[BUFF_ADDR:.*]] =  llvm.getelementptr %[[PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f16
-  // CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
-  // CHECK-DAG: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
-  // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
-  // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
-  // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
-  // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
+  // CHECK-DAG: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
+  // CHECK-DAG: %[[ALIGN:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK-NEXT: llvm.intr.assume %[[TRUE]] ["align"(%[[BUFF_ADDR]], %[[ALIGN]] : !llvm.ptr, i32)] : i1
   memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>>
   return
 }
@@ -410,7 +404,7 @@ func.func @atomic_rmw_with_offset(%I : memref<10xi32, strided<[1], offset: 5>>,
 // CHECK-SAME:   %[[ARG2:.+]]: index
 // CHECK-DAG:    %[[MEMREF_STRUCT:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<10xi32, strided<[1], offset: 5>> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK-DAG:    %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i64
-// CHECK:        %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+// CHECK:        %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK:        %[[OFFSET:.+]] = llvm.mlir.constant(5 : index) : i64
 // CHECK:        %[[OFFSET_PTR:.+]] = llvm.getelementptr %[[BASE_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
 // CHECK:        %[[PTR:.+]] = llvm.getelementptr %[[OFFSET_PTR]][%[[INDEX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
@@ -601,7 +595,7 @@ func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
 // CHECK-LABEL: func @extract_aligned_pointer_as_index_unranked
 func.func @extract_aligned_pointer_as_index_unranked(%m: memref<*xf32>) -> index {
   %0 = memref.extract_aligned_pointer_as_index %m: memref<*xf32> -> index
-  // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)> 
+  // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)>
   // CHECK: %[[ALIGNED_FIELD:.*]] = llvm.getelementptr %[[PTR]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
   // CHECK: %[[ALIGNED_PTR:.*]] = llvm.load %[[ALIGNED_FIELD]] : !llvm.ptr -> !llvm.ptr
   // CHECK: %[[I64:.*]] = llvm.ptrtoint %[[ALIGNED_PTR]] : !llvm.ptr to i64

Copy link
Collaborator

@qcolombet qcolombet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice clean up.
One nit below

loc, rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::eq,
rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
// Emit llvm.assume(true) "align"(memref, alignment).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: The parentheses on assume looks wrong

Now that LLVM allows a operand bundle on assume calls to directly
specify alignment assumptions, change the lowering of
memref.assume_alignment to use that feature instead of the ptrtoint
method.

This makes LLVM's job easier and prevents issues when dealing with
cases where ptrtoint isn't a desired operation (like those with poiner
provenance)
@krzysz00 krzysz00 force-pushed the users/krzysz00/assume-operand-bundles branch from 618466b to 6a870c1 Compare November 26, 2024 22:56
@krzysz00 krzysz00 merged commit 3359806 into main Nov 26, 2024
5 of 7 checks passed
@krzysz00 krzysz00 deleted the users/krzysz00/assume-operand-bundles branch November 26, 2024 23:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants