
Commit 3c0f7b1

Authored by Jimmy2027, nicolasvasilache, and ftynse

[mlir][transform] Add PromoteTensorOp (#158318)

Transform op to request a tensor value to live in a specific memory space after bufferization.

Co-authored-by: Nicolas Vasilache <[email protected]>
Co-authored-by: Alex Zinenko <[email protected]>

1 parent 332b4de commit 3c0f7b1

File tree

4 files changed: +239 -36 lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 45 additions & 4 deletions

@@ -17,6 +17,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/RegionKindInterface.td"
 
@@ -236,11 +237,51 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
                        Transform_AnyOpType:$new_ops);
   let assemblyFormat = "$target attr-dict `:` type($target)";
   let hasVerifier = 1;
+}
 
-  let builders = [
-    OpBuilder<(ins "Value":$target, "Attribute":$memorySpace)>,
-    OpBuilder<(ins "Value":$target, "int64_t":$memorySpace)>
-  ];
+//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+def PromoteTensorOp : Op<Transform_Dialect, "structured.promote_tensor",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     SameOperandsAndResultType]> {
+  let summary = "Request a tensor value to live in a specific memory space "
+                "after bufferization";
+  let description = [{
+    Requests that a tensor value lives in a specific memory space for its
+    lifetime. This is achieved by allocating a new tensor in the desired
+    memory space with `bufferization.alloc_tensor` and optionally materializing
+    the source value into that allocation with
+    `bufferization.materialize_in_destination`. All uses of the original value
+    are then redirected to the promoted value.
+
+    The generated code for promoting tensor value %0 resembles the following:
+
+      %1 = bufferization.alloc_tensor(<dynamic dims of %0>)
+           { memory_space = memory_space }
+      // Note: the materialization is omitted if %0 is never read and is only
+      // written into (i.e., it behaves as a result tensor).
+      %2 = bufferization.materialize_in_destination %0 in %1
+      // ...
+      <all users of %0 now use %2 instead>
+
+    Deallocation is not handled by this transform.
+
+    Return modes:
+    - Produces a silenceable failure if the given handle does not point to
+      tensor-typed values.
+    - Succeeds otherwise and returns a handle to the promoted value(s), i.e.,
+      the result of materialization if present and the allocation otherwise.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$tensor,
+                       OptionalAttr<AnyAttr>:$memory_space);
+  let results = (outs TransformValueHandleTypeInterface:$promoted);
+
+  let assemblyFormat =
+      "(`to` $memory_space^)? $tensor attr-dict `:` type($tensor)";
 }
 
 //===----------------------------------------------------------------------===//
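Per the assemblyFormat above, the optional memory space follows the `to` keyword and the trailing type is that of the value handle. A minimal use inside a transform script, mirroring the tests added by this commit (SSA names such as %matmul are illustrative), looks like:

  %operand = transform.get_operand %matmul[0]
      : (!transform.any_op) -> !transform.any_value
  %promoted = transform.structured.promote_tensor to 1 %operand
      : !transform.any_value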

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 90 additions & 26 deletions

@@ -42,6 +42,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/LogicalResult.h"
@@ -273,32 +274,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
 
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
-                                               OperationState &result,
-                                               Value target,
-                                               Attribute memorySpace) {
-  SmallVector<Type> resultTypes;
-  resultTypes.push_back(b.getType<transform::AnyValueType>());
-  resultTypes.push_back(b.getType<transform::AnyOpType>());
-  return build(b, result,
-               /*resultTypes=*/resultTypes,
-               /*target=*/target,
-               /*memory_space=*/memorySpace);
-}
-
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
-                                               OperationState &result,
-                                               Value target,
-                                               int64_t memorySpace) {
-  SmallVector<Type> resultTypes;
-  resultTypes.push_back(b.getType<transform::AnyValueType>());
-  resultTypes.push_back(b.getType<transform::AnyOpType>());
-  return build(b, result,
-               /*resultTypes=*/resultTypes,
-               /*target=*/target,
-               /*memory_space=*/b.getI64IntegerAttr(memorySpace));
-}
-
 namespace {
 class NewOpsListener : public RewriterBase::ForwardingListener {
 public:
@@ -408,6 +383,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Return true if the operand may be read from by its owner. This is currently
+/// very conservative and only looks inside linalg operations to prevent
+/// unintentional data loss.
+static bool mayBeRead(OpOperand &operand) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
+
+  // Be conservative about ops we cannot analyze deeper.
+  if (!linalgOp)
+    return true;
+
+  // Look inside linalg ops.
+  Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
+  return !blockArgument.use_empty();
+}
+
+/// Return true if the value may be read through any of its uses.
+static bool mayBeRead(Value value) {
+  // If the value has reference semantics, it may be read through any alias...
+  if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
+    return true;
+  return llvm::any_of(value.getUses(),
+                      static_cast<bool (&)(OpOperand &)>(mayBeRead));
+}
+
+DiagnosedSilenceableFailure
+transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  SmallVector<Value> promoted;
+  for (Value tensor : state.getPayloadValues(getTensor())) {
+    auto type = dyn_cast<RankedTensorType>(tensor.getType());
+    if (!type) {
+      return emitSilenceableError() << "non-tensor type: " << tensor;
+    }
+
+    Operation *definingOp = tensor.getDefiningOp();
+    if (definingOp)
+      rewriter.setInsertionPointAfter(definingOp);
+    else
+      rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
+
+    // Check this before we emit operations using this value.
+    bool needsMaterialization = mayBeRead(tensor);
+
+    SmallVector<Value> dynamicDims;
+    llvm::SmallPtrSet<Operation *, 4> preservedOps;
+    for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
+      if (!ShapedType::isDynamic(dim))
+        continue;
+      Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
+      auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+      preservedOps.insert(dimOp);
+      dynamicDims.push_back(dimOp);
+    }
+    auto allocation = rewriter.create<bufferization::AllocTensorOp>(
+        tensor.getLoc(), type, dynamicDims);
+    // Set memory space if provided.
+    if (getMemorySpaceAttr())
+      allocation.setMemorySpaceAttr(getMemorySpaceAttr());
+    Value allocated = allocation;
+
+    // Only insert a materialization (typically bufferizes to a copy) when the
+    // value may be read from.
+    if (needsMaterialization) {
+      auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
+          tensor.getLoc(), tensor, allocated);
+      preservedOps.insert(copy);
+      promoted.push_back(copy.getResult());
+    } else {
+      promoted.push_back(allocated);
+    }
+    rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
+  }
+  results.setValues(cast<OpResult>(getPromoted()), promoted);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::PromoteTensorOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTensorMutable(), effects);
+  transform::producesHandle(getOperation()->getOpResults(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // DecomposeOp
 //===----------------------------------------------------------------------===//
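For reference, the rewrite performed by apply() above turns a promoted tensor value into roughly the following payload IR. This is a hand-written sketch consistent with the op description and the FileCheck patterns in the new test; the SSA names, the tensor type, and memory space 1 are illustrative:

  // %t : tensor<?x42xf32>, read later by its consumer.
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %t, %c0 : tensor<?x42xf32>
  %alloc = bufferization.alloc_tensor(%dim) {memory_space = 1 : i64} : tensor<?x42xf32>
  %mat = bufferization.materialize_in_destination %t in %alloc
      : (tensor<?x42xf32>, tensor<?x42xf32>) -> tensor<?x42xf32>
  // All uses of %t except the tensor.dim above are redirected to %mat;
  // if %t were only written to, the materialization would be omitted and
  // uses would be redirected to %alloc instead.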

mlir/python/mlir/dialects/transform/structured.py

Lines changed: 0 additions & 6 deletions

@@ -44,18 +44,12 @@ def __init__(
         loc=None,
         ip=None,
     ):
-        # No other types are allowed, so hard-code those here.
-        allocated_buffer_type = transform.AnyValueType.get()
-        new_ops_type = transform.AnyOpType.get()
-
         if isinstance(memory_space, int):
             memory_space = str(memory_space)
         if isinstance(memory_space, str):
             memory_space = Attribute.parse(memory_space)
 
         super().__init__(
-            allocated_buffer_type,
-            new_ops_type,
             target,
             memory_space=memory_space,
             memcpy_op=memcpy_op,
Lines changed: 104 additions & 0 deletions

@@ -0,0 +1,104 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

// CHECK-LABEL: @promote_in0
// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}, %{{.*}})
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM]]) {memory_space = 1 : i64}
// CHECK: %[[MAT:.+]] = bufferization.materialize_in_destination %[[ARG0]] in %[[ALLOC]]
// CHECK: linalg.matmul ins(%[[MAT]], %{{.*}}
func.func @promote_in0(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>)
                     outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    %mm = transform.structured.match ops{["linalg.matmul"]} in %root
        : (!transform.any_op) -> !transform.any_op
    %op0 = transform.get_operand %mm[0]
        : (!transform.any_op) -> !transform.any_value
    transform.structured.promote_tensor to 1 %op0 : !transform.any_value
    transform.yield
  }
}

// -----

// CHECK-LABEL: @promote_out
// CHECK-SAME: (%{{.*}}: tensor<?x42xf32>, %{{.*}}: tensor<?x42xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
func.func @promote_out(%arg0: tensor<?x42xf32>, %arg1: tensor<?x42xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // CHECK: %[[C0:.+]] = arith.constant 0
  // CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[C0]]
  // CHECK: %[[C1:.+]] = arith.constant 1
  // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
  // CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM0]], %[[DIM1]]) {memory_space = 1 : i64}
  // CHECK-NOT: materialize_in_destination
  // CHECK: linalg.add {{.*}} outs(%[[ALLOC]]
  %0 = linalg.add ins(%arg0, %arg1 : tensor<?x42xf32>, tensor<?x42xf32>)
                  outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    %la = transform.structured.match ops{["linalg.add"]} in %root
        : (!transform.any_op) -> !transform.any_op
    %init = transform.get_operand %la[2]
        : (!transform.any_op) -> !transform.any_value
    transform.structured.promote_tensor to 1 %init : !transform.any_value

    transform.yield
  }
}

// -----

// CHECK-LABEL: @promote_in0_out_bufferize
// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}: tensor<42x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
func.func @promote_in0_out_bufferize(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // CHECK: %[[IN1:.+]] = bufferization.to_buffer %arg1 : tensor<42x?xf32> to memref<42x?xf32, strided<[?, ?], offset: ?>>
  // CHECK: %[[IN0:.+]] = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
  // CHECK: %{{.+}} = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
  // CHECK: %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  // CHECK: %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C0]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
  // CHECK: %[[C1:.+]] = arith.constant 1 : index
  // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C1]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
  // CHECK: %[[ALLOC_OUT:.+]] = memref.alloc(%{{.+}}, %{{.+}}) {alignment = 64 : i64} : memref<?x?xf32, 1>
  // CHECK: %{{.+}} = arith.constant 0 : index
  // CHECK: %{{.+}} = memref.dim %{{.+}}, %{{.+}} : memref<?x42xf32, strided<[?, ?], offset: ?>>
  // CHECK: %[[ALLOC_IN:.+]] = memref.alloc(%{{.+}}) {alignment = 64 : i64} : memref<?x42xf32, 1>
  // CHECK: memref.copy %[[IN0]], %[[ALLOC_IN]] : memref<?x42xf32, strided<[?, ?], offset: ?>> to memref<?x42xf32, 1>
  // CHECK: linalg.add ins(%[[ALLOC_IN]], %[[IN1]] : memref<?x42xf32, 1>, memref<42x?xf32, strided<[?, ?], offset: ?>>) outs(%[[ALLOC_OUT]] : memref<?x?xf32, 1>)
  %0 = linalg.add ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>)
                  outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    %la = transform.structured.match ops{["linalg.add"]} in %root
        : (!transform.any_op) -> !transform.any_op
    %op0 = transform.get_operand %la[0]
        : (!transform.any_op) -> !transform.any_value
    transform.structured.promote_tensor to 1 %op0 : !transform.any_value

    %init = transform.get_operand %la[2]
        : (!transform.any_op) -> !transform.any_value
    transform.structured.promote_tensor to 1 %init : !transform.any_value

    %func = transform.structured.match ops{["func.func"]} in %root
        : (!transform.any_op) -> !transform.any_op

    %bufferized = transform.bufferization.one_shot_bufferize %func
        : (!transform.any_op) -> !transform.any_op

    transform.yield
  }
}
