Skip to content

Commit 7a513df

Browse files
committed
[mlir][memref] Introduce memref.distinct_objects op
1 parent 42b195e commit 7a513df

File tree

5 files changed

+130
-5
lines changed

5 files changed

+130
-5
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
153153
The `assume_alignment` operation takes a memref and an integer alignment
154154
value. It returns a new SSA value of the same memref type, but associated
155155
with the assumption that the underlying buffer is aligned to the given
156-
alignment.
156+
alignment.
157157

158158
If the buffer isn't aligned to the given alignment, its result is poison.
159159
This operation doesn't affect the semantics of a program where the
@@ -168,14 +168,49 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
168168
let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
169169
let extraClassDeclaration = [{
170170
MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }
171-
171+
172172
Value getViewSource() { return getMemref(); }
173173
}];
174174

175175
let hasVerifier = 1;
176176
let hasFolder = 1;
177177
}
178178

179+
//===----------------------------------------------------------------------===//
180+
// DistinctObjectsOp
181+
//===----------------------------------------------------------------------===//
182+
183+
def DistinctObjectsOp : MemRef_Op<"distinct_objects", [
184+
Pure,
185+
DeclareOpInterfaceMethods<InferTypeOpInterface>
186+
// ViewLikeOpInterface TODO: ViewLikeOpInterface only supports a single argument
187+
]> {
188+
let summary = "assumption that acesses to specific memrefs will never alias";
189+
let description = [{
190+
The `distinct_objects` operation takes a list of memrefs and returns a list of
191+
memrefs of the same types, with the additional assumption that accesses to
192+
these memrefs will never alias with each other. This means that loads and
193+
stores to different memrefs in the list can be safely reordered.
194+
195+
If the memrefs do alias, the behavior is undefined. This operation doesn't
196+
affect the semantics of a program where the non-aliasing assumption holds
197+
true. It is intended for optimization purposes, allowing the compiler to
198+
generate more efficient code based on the non-aliasing assumption. The
199+
optimization is best-effort.
200+
201+
Example:
202+
203+
```mlir
204+
%1, %2 = memref.distinct_objects %a, %b : memref<?xf32>, memref<?xf32>
205+
```
206+
}];
207+
let arguments = (ins Variadic<AnyMemRef>:$operands);
208+
let results = (outs Variadic<AnyMemRef>:$results);
209+
210+
let assemblyFormat = "$operands attr-dict `:` type($operands)";
211+
let hasVerifier = 1;
212+
}
213+
179214
//===----------------------------------------------------------------------===//
180215
// AllocOp
181216
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,48 @@ struct AssumeAlignmentOpLowering
465465
}
466466
};
467467

468+
struct DistinctObjectsOpLowering
469+
: public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> {
470+
using ConvertOpToLLVMPattern<
471+
memref::DistinctObjectsOp>::ConvertOpToLLVMPattern;
472+
explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter)
473+
: ConvertOpToLLVMPattern<memref::DistinctObjectsOp>(converter) {}
474+
475+
LogicalResult
476+
matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
477+
ConversionPatternRewriter &rewriter) const override {
478+
ValueRange operands = adaptor.getOperands();
479+
if (operands.empty()) {
480+
rewriter.eraseOp(op);
481+
return success();
482+
}
483+
Location loc = op.getLoc();
484+
SmallVector<Value> ptrs;
485+
for (auto [origOperand, newOperand] :
486+
llvm::zip_equal(op.getOperands(), operands)) {
487+
auto memrefType = cast<MemRefType>(origOperand.getType());
488+
Value ptr = getStridedElementPtr(rewriter, loc, memrefType, newOperand,
489+
/*indices=*/{});
490+
ptrs.push_back(ptr);
491+
}
492+
493+
auto cond =
494+
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
495+
// Generate separate_storage assumptions for each pair of pointers.
496+
for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
497+
for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
498+
Value ptr1 = ptrs[i];
499+
Value ptr2 = ptrs[j];
500+
LLVM::AssumeOp::create(rewriter, loc, cond,
501+
LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2);
502+
}
503+
}
504+
505+
rewriter.replaceOp(op, operands);
506+
return success();
507+
}
508+
};
509+
468510
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
469511
// The memref descriptor being an SSA value, there is no need to clean it up
470512
// in any way.
@@ -1997,22 +2039,23 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
19972039
patterns.add<
19982040
AllocaOpLowering,
19992041
AllocaScopeOpLowering,
2000-
AtomicRMWOpLowering,
20012042
AssumeAlignmentOpLowering,
2043+
AtomicRMWOpLowering,
20022044
ConvertExtractAlignedPointerAsIndex,
20032045
DimOpLowering,
2046+
DistinctObjectsOpLowering,
20042047
ExtractStridedMetadataOpLowering,
20052048
GenericAtomicRMWOpLowering,
20062049
GetGlobalMemrefOpLowering,
20072050
LoadOpLowering,
20082051
MemRefCastOpLowering,
2009-
MemorySpaceCastOpLowering,
20102052
MemRefReinterpretCastOpLowering,
20112053
MemRefReshapeOpLowering,
2054+
MemorySpaceCastOpLowering,
20122055
PrefetchOpLowering,
20132056
RankOpLowering,
2014-
ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
20152057
ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2058+
ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
20162059
StoreOpLowering,
20172060
SubViewOpLowering,
20182061
TransposeOpLowering,

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,25 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
542542
return getMemref();
543543
}
544544

545+
//===----------------------------------------------------------------------===//
546+
// DistinctObjectsOp
547+
//===----------------------------------------------------------------------===//
548+
549+
LogicalResult DistinctObjectsOp::verify() {
550+
if (getOperandTypes() != getResultTypes())
551+
return emitOpError("operand types and result types must match");
552+
return success();
553+
}
554+
555+
LogicalResult DistinctObjectsOp::inferReturnTypes(
556+
MLIRContext * /*context*/, std::optional<Location> /*location*/,
557+
ValueRange operands, DictionaryAttr /*attributes*/,
558+
OpaqueProperties /*properties*/, RegionRange /*regions*/,
559+
SmallVectorImpl<Type> &inferredReturnTypes) {
560+
llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
561+
return success();
562+
}
563+
545564
//===----------------------------------------------------------------------===//
546565
// CastOp
547566
//===----------------------------------------------------------------------===//

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,25 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) {
195195

196196
// -----
197197

198+
// ALL-LABEL: func @distinct_objects
199+
// ALL-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
200+
func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
201+
// ALL-DAG: %[[CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xf16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
202+
// ALL-DAG: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<?xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
203+
// ALL-DAG: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : memref<?xf64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
204+
// ALL: %[[PTR_0:.*]] = llvm.extractvalue %[[CAST_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
205+
// ALL: %[[PTR_1:.*]] = llvm.extractvalue %[[CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
206+
// ALL: %[[PTR_2:.*]] = llvm.extractvalue %[[CAST_2]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
207+
// ALL: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
208+
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_1]] : !llvm.ptr, !llvm.ptr)] : i1
209+
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
210+
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_1]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
211+
%1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
212+
return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
213+
}
214+
215+
// -----
216+
198217
// CHECK-LABEL: func @assume_alignment_w_offset
199218
// CHECK-INTERFACE-LABEL: func @assume_alignment_w_offset
200219
func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) {

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,15 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
302302
return
303303
}
304304

305+
// CHECK-LABEL: func @distinct_objects
306+
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
307+
func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
308+
// CHECK: %[[RES:.*]]:3 = memref.distinct_objects %[[ARG0]], %[[ARG1]], %[[ARG2]] : memref<?xf16>, memref<?xf32>, memref<?xf64>
309+
%1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
310+
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
311+
return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
312+
}
313+
305314
// CHECK-LABEL: func @expand_collapse_shape_static
306315
func.func @expand_collapse_shape_static(
307316
%arg0: memref<3x4x5xf32>,

0 commit comments

Comments
 (0)