diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 4bfa536cc8a44..86f687d7f2636 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -192,22 +192,15 @@ struct AssumeAlignmentOpLowering Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{}, rewriter); - // Emit llvm.assume(memref & (alignment - 1) == 0). - // - // This relies on LLVM's CSE optimization (potentially after SROA), since - // after CSE all memref instances should get de-duplicated into the same - // pointer SSA value. - MemRefDescriptor memRefDescriptor(memref); - auto intPtrType = - getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace()); - Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0); - Value mask = - createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1); - Value ptrValue = rewriter.create(loc, intPtrType, ptr); - rewriter.create( - loc, rewriter.create( - loc, LLVM::ICmpPredicate::eq, - rewriter.create(loc, ptrValue, mask), zero)); + // Emit llvm.assume(true) ["align"(memref, alignment)]. + // This is more direct than ptrtoint-based checks, is explicitly supported, + // and works with non-integral address spaces. + Value trueCond = + rewriter.create(loc, rewriter.getBoolAttr(true)); + Value alignmentConst = + createIndexAttrConstant(rewriter, loc, getIndexType(), alignment); + rewriter.create(loc, trueCond, LLVM::AssumeAlignTag(), ptr, + alignmentConst); rewriter.eraseOp(op); return success(); diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir index ec5ceae57ccb3..a78db9733b7ee 100644 --- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir @@ -675,10 +675,7 @@ func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf // CHECK: %[[ALIGNED_PTR:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[ALIGNED_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 -// CHECK: %[[INT_TO_PTR:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64 -// CHECK: %[[AND:.*]] = llvm.and %[[INT_TO_PTR]], {{.*}} : i64 -// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[AND]], {{.*}} : i64 -// CHECK: llvm.intr.assume %[[CMP]] : i1 +// CHECK: llvm.intr.assume %{{.*}} ["align"(%[[BUFF_ADDR]], %{{.*}} : !llvm.ptr, i64)] : i1 // CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: %[[VAL:.*]] = llvm.load %[[LD_ADDR]] : !llvm.ptr -> f32 // CHECK: return %[[VAL]] : f32 diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index 48dc9079333d4..67b68b7a1c044 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -155,12 +155,9 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in // CHECK-LABEL: func @assume_alignment( func.func @assume_alignment(%0 : memref<4x4xf16>) { // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64 - // CHECK-NEXT: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64 - // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm.ptr to i64 - // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64 - // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64 - // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1 + // CHECK-NEXT: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1 + // CHECK-NEXT: %[[ALIGN:.*]] = llvm.mlir.constant(16 : index) : i64 + // CHECK-NEXT: llvm.intr.assume %[[TRUE]] ["align"(%[[PTR]], %[[ALIGN]] : !llvm.ptr, i64)] : i1 memref.assume_alignment %0, 16 : memref<4x4xf16> return } @@ -172,12 +169,9 @@ func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset // CHECK-DAG: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-DAG: %[[OFFSET:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-DAG: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f16 - // CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64 - // CHECK-DAG: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64 - // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64 - // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64 - // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64 - // CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1 + // CHECK-DAG: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1 + // CHECK-DAG: %[[ALIGN:.*]] = llvm.mlir.constant(16 : index) : i64 + // CHECK-NEXT: llvm.intr.assume %[[TRUE]] ["align"(%[[BUFF_ADDR]], %[[ALIGN]] : !llvm.ptr, i64)] : i1 memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>> return } @@ -410,7 +404,7 @@ func.func @atomic_rmw_with_offset(%I : memref<10xi32, strided<[1], offset: 5>>, // CHECK-SAME: %[[ARG2:.+]]: index // CHECK-DAG: %[[MEMREF_STRUCT:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<10xi32, strided<[1], offset: 5>> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-DAG: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i64 -// CHECK: %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(5 : index) : i64 // CHECK: %[[OFFSET_PTR:.+]] = llvm.getelementptr %[[BASE_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32 // CHECK: %[[PTR:.+]] = llvm.getelementptr %[[OFFSET_PTR]][%[[INDEX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32 @@ -601,7 +595,7 @@ func.func @extract_aligned_pointer_as_index(%m: memref) -> index { // CHECK-LABEL: func @extract_aligned_pointer_as_index_unranked func.func @extract_aligned_pointer_as_index_unranked(%m: memref<*xf32>) -> index { %0 = memref.extract_aligned_pointer_as_index %m: memref<*xf32> -> index - // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)> + // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)> // CHECK: %[[ALIGNED_FIELD:.*]] = llvm.getelementptr %[[PTR]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr // CHECK: %[[ALIGNED_PTR:.*]] = llvm.load %[[ALIGNED_FIELD]] : !llvm.ptr -> !llvm.ptr // CHECK: %[[I64:.*]] = llvm.ptrtoint %[[ALIGNED_PTR]] : !llvm.ptr to i64