diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index f34b5b46cab50..54ac899f96f06 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -142,22 +142,37 @@ class AllocLikeOp<string mnemonic,
 //===----------------------------------------------------------------------===//
 // AssumeAlignmentOp
 //===----------------------------------------------------------------------===//
 
-def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
+def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
+      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      Pure,
+      ViewLikeOpInterface,
+      SameOperandsAndResultType
+    ]> {
   let summary =
-      "assertion that gives alignment information to the input memref";
+      "assumption that gives alignment information to the input memref";
   let description = [{
-    The `assume_alignment` operation takes a memref and an integer of alignment
-    value, and internally annotates the buffer with the given alignment. If
-    the buffer isn't aligned to the given alignment, the behavior is undefined.
+    The `assume_alignment` operation takes a memref and an integer alignment
+    value. It returns a new SSA value of the same memref type, but associated
+    with the assumption that the underlying buffer is aligned to the given
+    alignment.
 
-    This operation doesn't affect the semantics of a correct program. It's for
-    optimization only, and the optimization is best-effort.
+    If the buffer isn't aligned to the given alignment, its result is poison.
+    This operation doesn't affect the semantics of a program where the
+    alignment assumption holds true. It is intended for optimization purposes,
+    allowing the compiler to generate more efficient code based on the
+    alignment assumption. The optimization is best-effort.
   }];
   let arguments = (ins AnyMemRef:$memref,
                        ConfinedAttr<I32Attr, [IntPositive]>:$alignment);
-  let results = (outs);
+  let results = (outs AnyMemRef:$result);
 
   let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
+  let extraClassDeclaration = [{
+    MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }
+
+    Value getViewSource() { return getMemref(); }
+  }];
+
   let hasVerifier = 1;
 }
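Reviewer note (illustrative, not part of the patch): a minimal usage sketch of the new result-carrying form. The function and value names below are invented for illustration; only the `memref.assume_alignment` syntax comes from the op definition above.

```mlir
func.func @aligned_load_example(%buf: memref<4096xf32>) -> f32 {
  %c0 = arith.constant 0 : index
  // The result is a new SSA value of the same type; it is poison if %buf
  // is not actually 64-byte aligned.
  %aligned = memref.assume_alignment %buf, 64 : memref<4096xf32>
  // Downstream ops should use %aligned rather than %buf, so the alignment
  // fact is carried along the use-def chain instead of being a side effect.
  %v = memref.load %aligned[%c0] : memref<4096xf32>
  return %v : f32
}
```

Because the op is now `Pure`, an unused result can simply be dead-code-eliminated, which is why the special case removed from EliminateBarriers.cpp below is no longer needed.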
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 158de6dea58c9..7f45904fab7e1 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -432,8 +432,7 @@ struct AssumeAlignmentOpLowering
         createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
     rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                                     alignmentConst);
-
-    rewriter.eraseOp(op);
+    rewriter.replaceOp(op, memref);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
index 84c12c0ba05e5..912d4dc99a885 100644
--- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
@@ -44,13 +44,6 @@ using namespace mlir::gpu;
 // The functions below provide interface-like verification, but are too specific
 // to barrier elimination to become interfaces.
 
-/// Implement the MemoryEffectsOpInterface in the suitable way.
-static bool isKnownNoEffectsOpWithoutInterface(Operation *op) {
-  // memref::AssumeAlignment is conceptually pure, but marking it as such would
-  // make DCE immediately remove it.
-  return isa<memref::AssumeAlignmentOp>(op);
-}
-
 /// Returns `true` if the op is defines the parallel region that is subject to
 /// barrier synchronization.
 static bool isParallelRegionBoundary(Operation *op) {
@@ -101,10 +94,6 @@ collectEffects(Operation *op,
   if (ignoreBarriers && isa<BarrierOp>(op))
     return true;
 
-  // Skip over ops that we know have no effects.
-  if (isKnownNoEffectsOpWithoutInterface(op))
-    return true;
-
   // Collect effect instances the operation. Note that the implementation of
   // getEffects erases all effect instances that have the type other than the
   // template parameter so we collect them first in a local buffer and then
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a0237c18cf2fe..82702789c2913 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -527,6 +527,11 @@ LogicalResult AssumeAlignmentOp::verify() {
   return success();
 }
 
+void AssumeAlignmentOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "assume_align");
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 59cfce28e07e1..d2a032688fb6d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -229,7 +229,7 @@ struct ConvertMemRefAssumeAlignment final
     }
 
     rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
-        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
+        op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index e9a80be87a0f7..cfd529c46a41d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -919,6 +919,35 @@ struct ExtractStridedMetadataOpGetGlobalFolder
   }
 };
 
+/// Pattern to replace `extract_strided_metadata(assume_alignment)`
+///
+/// With
+/// \verbatim
+/// extract_strided_metadata(memref)
+/// \endverbatim
+///
+/// Since `assume_alignment` is a view-like op that does not modify the
+/// underlying buffer, offset, sizes, or strides, extracting strided metadata
+/// from its result is equivalent to extracting it from its source. This
+/// canonicalization removes the unnecessary indirection.
+struct ExtractStridedMetadataOpAssumeAlignmentFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto assumeAlignmentOp =
+        op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
+    if (!assumeAlignmentOp)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(
+        op, assumeAlignmentOp.getViewSource());
+    return success();
+  }
+};
+
 /// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
 /// source of the ViewLikeOp.
 class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
@@ -1185,6 +1214,7 @@ void memref::populateExpandStridedMetadataPatterns(
       ExtractStridedMetadataOpSubviewFolder,
       ExtractStridedMetadataOpCastFolder,
       ExtractStridedMetadataOpMemorySpaceCastFolder,
+      ExtractStridedMetadataOpAssumeAlignmentFolder,
       ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
@@ -1201,6 +1231,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
       ExtractStridedMetadataOpReinterpretCastFolder,
       ExtractStridedMetadataOpCastFolder,
       ExtractStridedMetadataOpMemorySpaceCastFolder,
+      ExtractStridedMetadataOpAssumeAlignmentFolder,
       ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
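Reviewer note (illustrative, not part of the patch): a hand-written before/after sketch of the rewrite performed by `ExtractStridedMetadataOpAssumeAlignmentFolder`; `%m` and the rank-1 memref type are placeholders.

```mlir
// Before: the metadata is extracted from the assume_alignment result.
%a = memref.assume_alignment %m, 64 : memref<?xf32>
%base, %offset, %size, %stride = memref.extract_strided_metadata %a
    : memref<?xf32> -> memref<f32>, index, index, index

// After: the pattern looks through the view-like op to its source, which
// is valid because assume_alignment changes neither the base buffer nor
// the offset, sizes, or strides.
%base, %offset, %size, %stride = memref.extract_strided_metadata %m
    : memref<?xf32> -> memref<f32>, index, index, index
```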
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index fe91d26d5a251..8dd7edf3e29b1 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -683,7 +683,7 @@ func.func @load_and_assume(
     %arg0: memref<?x?xf32, strided<[?, ?], offset: ?>>,
     %i0: index, %i1: index)
     -> f32 {
-  memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
-  %2 = memref.load %arg0[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %arg0_align = memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %2 = memref.load %arg0_align[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
   func.return %2 : f32
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir
index 05e52848ca877..c50c25ad8194f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir
@@ -10,4 +10,11 @@ func.func @func_with_assert(%arg0: index, %arg1: index) {
   %0 = arith.cmpi slt, %arg0, %arg1 : index
   cf.assert %0, "%arg0 must be less than %arg1"
   return
+}
+
+// CHECK-LABEL: func @func_with_assume_alignment(
+// CHECK: %0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
+func.func @func_with_assume_alignment(%arg0: memref<128xi8>) {
+  %0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
+  return
 }
\ No newline at end of file
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 0cb3b7b744476..111a02abcc74c 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -63,8 +63,8 @@ func.func @memref_load_i4(%arg0: index) -> i4 {
 
 func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
   %0 = memref.alloc() : memref<3x125xi4>
-  memref.assume_alignment %0, 64 : memref<3x125xi4>
-  %1 = memref.load %0[%arg0,%arg1] : memref<3x125xi4>
+  %align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
+  %1 = memref.load %align0[%arg0,%arg1] : memref<3x125xi4>
   return %1 : i4
 }
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
 // CHECK: func @memref_load_i4_rank2(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
-// CHECK: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+// CHECK: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
 // CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
-// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
 // CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
 // CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
 // CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -88,9 +88,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
 // CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
-// CHECK32: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+// CHECK32: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
 // CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
-// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK32: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
 // CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
 // CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
 // CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -350,8 +350,8 @@ func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
 
 func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
   %0 = memref.alloc() : memref<3x125xi4>
-  memref.assume_alignment %0, 64 : memref<3x125xi4>
-  memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
+  %align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
+  memref.store %arg2, %align0[%arg0,%arg1] : memref<3x125xi4>
   return
 }
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
@@ -359,7 +359,7 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
 // CHECK: func @memref_store_i4_rank2(
 // CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
 // CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
-// CHECK-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+// CHECK-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
 // CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
 // CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
 // CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -369,8 +369,8 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
 // CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
 // CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
 // CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
 // CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
 // CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
 // CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
-// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
-// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
 // CHECK: return
 
 // CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
 // CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
@@ -378,7 +378,7 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
 // CHECK32: func @memref_store_i4_rank2(
 // CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
 // CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
-// CHECK32-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+// CHECK32-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
 // CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
 // CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
 // CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -388,8 +388,8 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
 // CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
 // CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
 // CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
 // CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
 // CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
 // CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
-// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
-// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
 // CHECK32: return
 
 // -----
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 34fc4775924e7..f908efb638446 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -878,7 +878,7 @@ func.func @invalid_memref_cast() {
 // alignment is not power of 2.
 func.func @assume_alignment(%0: memref<4x4xf16>) {
   // expected-error@+1 {{alignment must be power of 2}}
-  memref.assume_alignment %0, 12 : memref<4x4xf16>
+  %1 = memref.assume_alignment %0, 12 : memref<4x4xf16>
   return
 }
 
@@ -887,7 +887,7 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
 // 0 alignment value.
 func.func @assume_alignment(%0: memref<4x4xf16>) {
   // expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
-  memref.assume_alignment %0, 0 : memref<4x4xf16>
+  %1 = memref.assume_alignment %0, 0 : memref<4x4xf16>
   return
 }
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7038a6ff744e4..38ee363a7d424 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -284,7 +284,7 @@ func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
 func.func @assume_alignment(%0: memref<4x4xf16>) {
   // CHECK: memref.assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
-  memref.assume_alignment %0, 16 : memref<4x4xf16>
+  %1 = memref.assume_alignment %0, 16 : memref<4x4xf16>
   return
 }
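Reviewer note (illustrative, not part of the patch): a rough, hand-written sketch of what `AssumeAlignmentOpLowering` in MemRefToLLVM.cpp produces after this change, shown for a 1-D memref; the descriptor value and constants are invented for illustration. The op lowers to an `llvm.intr.assume` carrying an "align" operand bundle, and the op's result is then replaced by the original memref descriptor (`rewriter.replaceOp(op, memref)`), so no new descriptor is materialized.

```mlir
// %desc is the lowered memref descriptor; field 1 is the aligned pointer.
%ptr = llvm.extractvalue %desc[1]
    : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%true = llvm.mlir.constant(true) : i1
%align = llvm.mlir.constant(16 : i64) : i64
llvm.intr.assume %true ["align"(%ptr, %align : !llvm.ptr, i64)] : i1
```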