Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -142,22 +142,34 @@ class AllocLikeOp<string mnemonic,
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
NoMemoryEffect,
ViewLikeOpInterface,
SameOperandsAndResultType
]> {
let summary =
"assertion that gives alignment information to the input memref";
let description = [{
The `assume_alignment` operation takes a memref and an integer of alignment
value, and internally annotates the buffer with the given alignment. If
the buffer isn't aligned to the given alignment, the behavior is undefined.
The `assume_alignment` operation takes a memref and an integer of alignment
value. It returns a new SSA value of the same memref type, but associated
with the assertion that the underlying buffer is aligned to the given
alignment. If the buffer isn't aligned to the given alignment, the
behavior is undefined.

This operation doesn't affect the semantics of a correct program. It's for
optimization only, and the optimization is best-effort.
This operation doesn't affect the semantics of a correct program. It's for
optimization only, and the optimization is best-effort.
}];
let arguments = (ins AnyMemRef:$memref,
ConfinedAttr<I32Attr, [IntPositive]>:$alignment);
let results = (outs);
let results = (outs AnyMemRef:$result);

let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
let extraClassDeclaration = [{
MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }

Value getViewSource() { return getMemref(); }
}];

let hasVerifier = 1;
}

Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,7 @@ struct AssumeAlignmentOpLowering
createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
alignmentConst);

rewriter.eraseOp(op);
rewriter.replaceOp(op, memref);
return success();
}
};
Expand Down
10 changes: 1 addition & 9 deletions mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,7 @@ using namespace mlir::gpu;
// The functions below provide interface-like verification, but are too specific
// to barrier elimination to become interfaces.

/// Implement the MemoryEffectsOpInterface in the suitable way.
static bool isKnownNoEffectsOpWithoutInterface(Operation *op) {
// memref::AssumeAlignment is conceptually pure, but marking it as such would
// make DCE immediately remove it.
return isa<memref::AssumeAlignmentOp>(op);
}


/// Returns `true` if the op defines the parallel region that is subject to
/// barrier synchronization.
Expand Down Expand Up @@ -101,9 +96,6 @@ collectEffects(Operation *op,
if (ignoreBarriers && isa<BarrierOp>(op))
return true;

// Skip over ops that we know have no effects.
if (isKnownNoEffectsOpWithoutInterface(op))
return true;

// Collect effect instances of the operation. Note that the implementation of
// getEffects erases all effect instances that have the type other than the
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ struct ConvertMemRefAssumeAlignment final
}

rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
op, adaptor.getMemref(), adaptor.getAlignmentAttr());
op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
return success();
}
};
Expand Down
30 changes: 30 additions & 0 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,34 @@ struct ExtractStridedMetadataOpGetGlobalFolder
}
};

/// Forwards `extract_strided_metadata` through `assume_alignment`.
///
/// Rewrites
/// \verbatim
///   extract_strided_metadata(assume_alignment(memref))
/// \endverbatim
/// into
/// \verbatim
///   extract_strided_metadata(memref)
/// \endverbatim
///
/// `assume_alignment` is a view-like op whose result has the same type as its
/// operand, so the strided metadata can be read directly from the op's view
/// source instead of going through the extra SSA value.
struct ExtractStridedMetadataOpAssumeAlignmentFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    // The pattern only applies when the source is produced directly by an
    // `assume_alignment` op; anything else is left for other patterns.
    if (auto alignedView =
            op.getSource().getDefiningOp<memref::AssumeAlignmentOp>()) {
      rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(
          op, alignedView.getViewSource());
      return success();
    }
    return failure();
  }
};

/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
/// source of the ViewLikeOp.
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
Expand Down Expand Up @@ -1185,6 +1213,7 @@ void memref::populateExpandStridedMetadataPatterns(
ExtractStridedMetadataOpSubviewFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
Expand All @@ -1201,6 +1230,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ func.func @load_and_assume(
%arg0: memref<?x?xf32, strided<[?, ?], offset: ?>>,
%i0: index, %i1: index)
-> f32 {
memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
%2 = memref.load %arg0[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
%arg0_align = memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
%2 = memref.load %arg0_align[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
func.return %2 : f32
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,11 @@ func.func @func_with_assert(%arg0: index, %arg1: index) {
%0 = arith.cmpi slt, %arg0, %arg1 : index
cf.assert %0, "%arg0 must be less than %arg1"
return
}

// CHECK-LABEL: func @func_with_assume_alignment(
// CHECK: %0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
func.func @func_with_assume_alignment(%arg0: memref<128xi8>) {
%0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
return
}
28 changes: 14 additions & 14 deletions mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ func.func @memref_load_i4(%arg0: index) -> i4 {

func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
%0 = memref.alloc() : memref<3x125xi4>
memref.assume_alignment %0, 64 : memref<3x125xi4>
%1 = memref.load %0[%arg0,%arg1] : memref<3x125xi4>
%align0 =memref.assume_alignment %0, 64 : memref<3x125xi4>
%1 = memref.load %align0[%arg0,%arg1] : memref<3x125xi4>
return %1 : i4
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
Expand All @@ -73,9 +73,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
// CHECK: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
// CHECK: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
Expand All @@ -88,9 +88,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
// CHECK32: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
// CHECK32: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
Expand Down Expand Up @@ -350,16 +350,16 @@ func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {

func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
%0 = memref.alloc() : memref<3x125xi4>
memref.assume_alignment %0, 64 : memref<3x125xi4>
memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
%align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
memref.store %arg2, %align0[%arg0,%arg1] : memref<3x125xi4>
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
// CHECK: func @memref_store_i4_rank2(
// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
// CHECK-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
Expand All @@ -369,16 +369,16 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: return

// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
// CHECK32: func @memref_store_i4_rank2(
// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
// CHECK32-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
Expand All @@ -388,8 +388,8 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: return

// -----
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/MemRef/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ func.func @invalid_memref_cast() {
// alignment is not power of 2.
func.func @assume_alignment(%0: memref<4x4xf16>) {
// expected-error@+1 {{alignment must be power of 2}}
memref.assume_alignment %0, 12 : memref<4x4xf16>
%1 = memref.assume_alignment %0, 12 : memref<4x4xf16>
return
}

Expand All @@ -887,7 +887,7 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
// 0 alignment value.
func.func @assume_alignment(%0: memref<4x4xf16>) {
// expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
memref.assume_alignment %0, 0 : memref<4x4xf16>
%1 = memref.assume_alignment %0, 0 : memref<4x4xf16>
return
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/MemRef/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
func.func @assume_alignment(%0: memref<4x4xf16>) {
// CHECK: memref.assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
memref.assume_alignment %0, 16 : memref<4x4xf16>
%1 = memref.assume_alignment %0, 16 : memref<4x4xf16>
return
}

Expand Down