
Commit b65f967

[mlir][linalg] revise based on review comments

1 parent a38ba01

5 files changed: +138 −85 lines

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 24 additions & 0 deletions

@@ -243,6 +243,18 @@ def LinalgStructuredInterface
                                    utils::IteratorType::parallel);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if all loops are parallel.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isAllParallelLoops",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return getNumParallelLoops() == getNumLoops();
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the dims that are parallel loops.
@@ -327,6 +339,18 @@ def LinalgStructuredInterface
        return !getBlock()->getArgument(bbArgNumber).use_empty();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true only if linalgOp takes one input and produces one result.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isSingleInputOutput",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getNumDpsInputs() == 1 && $_op.getNumDpsInits() == 1;
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return true if `opOperand` is an init tensor. This is true when it is
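
For context, a minimal sketch of how the two new defaulted interface methods read at a call site; the helper and its name are illustrative only, not part of the commit:

    #include "mlir/Dialect/Linalg/IR/Linalg.h"

    // Any op implementing LinalgStructuredInterface can answer both queries
    // directly, since the methods above ship with default implementations.
    static bool hasCopyLikeStructure(mlir::linalg::LinalgOp op) {
      return op.isAllParallelLoops() && op.isSingleInputOutput();
    }

This mirrors the file-local static helpers that the commit removes from LinalgInterfaces.cpp below.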

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 18 additions & 0 deletions

@@ -210,6 +210,24 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
     }

     MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+
+    // Return true only if GenericOp has a single input and single
+    // output, and the body is a single yieldOp that yields the input.
+    // This check is useful when trying to determine if the op is
+    // essentially a transpose, broadcast, copy or something like that.
+    bool isSingleYieldOp() {
+      if (!isSingleInputOutput())
+        return false;
+      Block *body = getBody();
+      if (body->getOperations().size() != 1)
+        return false;
+
+      auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+      if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+          yieldOp->getOperand(0) != body->getArgument(0))
+        return false;
+      return true;
+    }
   }];

   let hasCanonicalizer = 1;
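
Likewise, a hypothetical matcher combining isSingleYieldOp with the interface methods above (a sketch; the function name is invented for illustration):

    #include "mlir/Dialect/Linalg/IR/Linalg.h"

    // A generic op with all-parallel loops, one input, one init, and a body
    // that only yields the input block argument moves data without computing;
    // this is the structural gate the fill/broadcast/transpose matchers use.
    static bool isPureDataMovement(mlir::linalg::GenericOp genericOp) {
      return genericOp.isAllParallelLoops() &&
             genericOp.isSingleInputOutput() && genericOp.isSingleYieldOp();
    }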

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 67 additions & 83 deletions

@@ -50,66 +50,40 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
   return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
 }

-// Returns true if all loops of the linalgOp are parallel
-static bool isAllParallel(LinalgOp op) {
-  return op.getNumParallelLoops() == op.getNumLoops();
-}
-
-// Returns true if and only if linalgOp takes one input and one init.
-static bool isSingleInputOutput(LinalgOp op) {
-  return op.getNumDpsInputs() == 1 && op.getNumDpsInits() == 1;
-}
-// Returns true if genericOp body is just a yieldOp that yields
-// input operand as result.
-static bool isSingleYieldOp(GenericOp op) {
-  if (op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1)
-    return false;
-
-  Block *body = op.getBody();
-  if (body->getOperations().size() != 1)
-    return false;
-
-  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
-  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
-      yieldOp->getOperand(0) != body->getArgument(0))
-    return false;
-  return true;
-}
-
 //===----------------------------------------------------------------------===//
 // CopyOpInterface implementation
 //===----------------------------------------------------------------------===//

-bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
-  // Structural and operands
-  if (!isAllParallel(linalgOp) || !isSingleInputOutput(linalgOp))
+bool linalg::isaCopyOpInterface(LinalgOp op) {
+  // Check all loops are parallel and the op has a single input and output.
+  if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
     return false;

-  auto mapRange = linalgOp.getIndexingMapsArray();
+  auto mapRange = op.getIndexingMapsArray();
   if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
       !mapRange.back().isIdentity()) {
     return false;
   }
   // Region.
-  return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
+  return llvm::hasSingleElement(op.getBlock()->getOperations());
 }

 //===----------------------------------------------------------------------===//
 // FillOpInterface implementation
 //===----------------------------------------------------------------------===//
-std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
+std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
   // Structural.
-  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
-      !isSingleYieldOp(genericOp))
+  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
+      !op.isSingleYieldOp())
     return std::nullopt;

   // Input should be referenced and init should not.
-  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
-      genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
+      op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
     return std::nullopt;

-  OpOperand *value = genericOp.getDpsInputOperand(0);
-  if (!genericOp.isScalar(value))
+  OpOperand *value = op.getDpsInputOperand(0);
+  if (!op.isScalar(value))
     return std::nullopt;
   return value->get();
 }

@@ -118,27 +92,30 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
 // BroadcastOpInterface implementation
 //===----------------------------------------------------------------------===//
 std::optional<SmallVector<int64_t>>
-linalg::isaBroadcastOpInterface(GenericOp genericOp) {
+linalg::isaBroadcastOpInterface(GenericOp op) {
   // Structural.
-  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
-      !isSingleYieldOp(genericOp))
+  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
+      !op.isSingleYieldOp())
     return std::nullopt;

-  auto t0 = genericOp.getDpsInputOperand(0)->get().getType();
-  auto t1 = genericOp.getDpsInitOperand(0)->get().getType();
-  if (!isa<MemRefType, RankedTensorType>(t0) ||
-      !isa<MemRefType, RankedTensorType>(t1))
+  auto srcTy = op.getDpsInputOperand(0)->get().getType();
+  auto dstTy = op.getDpsInitOperand(0)->get().getType();
+  if (!isa<MemRefType, RankedTensorType>(srcTy) ||
+      !isa<MemRefType, RankedTensorType>(dstTy))
     return std::nullopt;

-  // Check output is identity map. Injective function could also be
-  // a permutation of indices and expressible in linalg.generic but
-  // is not expressible for named broadcast op.
-  auto dstMap = genericOp.getIndexingMapsArray()[1];
+  // Check output is identity map. Broadcast could additionally employ
+  // a permutation of indices, which would be expressible in linalg.generic
+  // but is not expressible for the named broadcast op.
+  auto dstMap = op.getIndexingMapsArray()[1];
   if (!dstMap.isIdentity())
     return std::nullopt;

   SmallVector<int64_t> position;
-  auto srcMap = genericOp.getIndexingMapsArray()[0];
+  auto srcMap = op.getIndexingMapsArray()[0];
+
+  if (srcMap.getResults().size() >= dstMap.getResults().size())
+    return std::nullopt;

   // Check input map is monotonically increasing DimIds.
   for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
@@ -153,6 +130,7 @@ linalg::isaBroadcastOpInterface(GenericOp genericOp) {

   SmallVector<int64_t> broadcastedDims;
   auto numDims = srcMap.getNumDims();
+  // This is quadratic, but the number of items is generally small.
   for (auto dim : llvm::seq<int64_t>(0, numDims)) {
     if (!llvm::is_contained(position, dim))
       broadcastedDims.push_back(dim);
@@ -164,86 +142,92 @@ linalg::isaBroadcastOpInterface(GenericOp genericOp) {
 // TransposeOpInterface implementation
 //===----------------------------------------------------------------------===//
 std::optional<SmallVector<int64_t>>
-linalg::isaTransposeOpInterface(GenericOp genericOp) {
-  // Structural.
-  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
-      !isSingleYieldOp(genericOp))
+linalg::isaTransposeOpInterface(GenericOp op) {
+  // To specialize as a transpose op, the genericOp must have all parallel
+  // loops, a single input, a single output, and a body that is just a
+  // yield op, yielding the input unchanged (no compute).
+  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
+      !op.isSingleYieldOp())
     return std::nullopt;

-  // mapping checks.
-  auto mapRange = genericOp.getIndexingMapsArray();
-  if (mapRange.size() != 2 || !mapRange.back().isIdentity() ||
-      !mapRange.front().isPermutation())
+  auto mapRange = op.getIndexingMapsArray();
+  if (mapRange.size() != 2)
     return std::nullopt;

-  SmallVector<int64_t> permutation;
-  auto map = mapRange.front();
-  for (unsigned i = 0; i < map.getNumResults(); ++i) {
-    auto expr = llvm::cast<AffineDimExpr>(map.getResults()[i]);
-    permutation.push_back(expr.getPosition());
+  auto mapOfInput = mapRange.front();
+  auto mapOfResult = mapRange.back();
+
+  // linalg.transpose permutes the dimensions of the input using this
+  // rule: dim(result, i) = dim(input, permutation[i])
+  if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
+    return std::nullopt;
+
+  SmallVector<int64_t> permutation(mapOfInput.getNumDims());
+  for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
+    auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
+    permutation[expr.getPosition()] = i;
   }
   return permutation;
 }

 //===----------------------------------------------------------------------===//
 // Elementwise Single Unary/Binary-OpInterface implementation
 //===----------------------------------------------------------------------===//
-static bool
-isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
-                                          unsigned arity) {
+static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
+                                                      unsigned arity) {
   // Check all loops are parallel.
-  if (!isAllParallel(genericOp) || genericOp.getNumLoops() < 1)
+  if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
     return false;

   // Check there are `arity` inputs, one output, and all maps are identity.
-  if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
-      !llvm::all_of(genericOp.getIndexingMapsArray(),
+  if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
+      !llvm::all_of(op.getIndexingMapsArray(),
                     [](AffineMap map) { return map.isIdentity(); }))
     return false;

   // Init should not be referenced for elementwise operations.
-  if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
     return false;

   // A linalg.generic could be a series of elementwise ops, e.g. exp(neg(x)),
   // such as resulting from producer-consumer fusion. Here, we restrict to two
   // ops in the body, where the first is the elementwise single op and the
   // second a yield.
-  Block *body = genericOp.getBody();
+  Block *body = op.getBody();
   if (body->getOperations().size() != 2)
     return false;

-  Operation *op = &body->front();
-  if (op->getNumOperands() != arity || op->getNumResults() != 1)
+  Operation *oper = &body->front();
+  if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
     return false;

   auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
   if (!yieldOp || yieldOp.getNumOperands() != 1 ||
-      yieldOp->getOperand(0).getDefiningOp() != op)
+      yieldOp->getOperand(0).getDefiningOp() != oper)
     return false;
   return true;
 }

-bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) {
+bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
   // All basic elemwise checks.
-  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 1))
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))
     return false;

   // Check input is actually used.
-  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
+  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
     return false;
   return true;
 }

-bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) {
-  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 2))
+bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))
     return false;

   // Check both inputs are used (elementwise).
-  OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
-  OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
-  if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
-      !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
+  OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
+  OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
+  if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
+      !op.payloadUsesValueFromOperand(inputOpOperand1))
     return false;
   return true;
 }
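
The inversion in isaTransposeOpInterface (writing permutation[expr.getPosition()] = i rather than appending expr.getPosition()) follows from the quoted rule dim(result, i) = dim(input, permutation[i]). A self-contained sketch, with plain ints standing in for AffineDimExpr positions and values matching the transpose3D test below:

    #include <cassert>
    #include <vector>

    int main() {
      // Input indexing map (d0, d1, d2) -> (d1, d2, d0): result slot i of
      // the input map reads loop dim pos[i]; the output map is identity.
      std::vector<int> pos = {1, 2, 0};

      // Input dim i is indexed by loop dim pos[i], and result dim pos[i] is
      // that same loop dim (identity output map), so dim(result, pos[i]) =
      // dim(input, i), i.e. permutation[pos[i]] = i.
      std::vector<int> permutation(pos.size());
      for (int i = 0; i < static_cast<int>(pos.size()); ++i)
        permutation[pos[i]] = i;

      assert((permutation == std::vector<int>{2, 0, 1}));
      return 0;
    }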
Lines changed: 13 additions & 2 deletions

@@ -1,11 +1,22 @@
 // RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s

-// CHECK-LABEL: linalg_transpose
+// CHECK-LABEL: transpose2D
 // CHECK-SAME: %[[A:.+]]: tensor<16x64xf32>, %[[Out:.+]]: tensor<64x16xf32>
 // CHECK-NOT: linalg.generic
 // CHECK: %transposed = linalg.transpose ins(%[[A]] : tensor<16x64xf32>) outs(%[[Out]] : tensor<64x16xf32>) permutation = [1, 0]
 //
-func.func @linalg_transpose(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
+func.func @transpose2D(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
   %res = linalg.transpose ins(%A: tensor<16x64xf32>) outs(%Out: tensor<64x16xf32>) permutation = [1,0]
   return %res : tensor<64x16xf32>
 }
+
+
+// CHECK-LABEL: transpose3D
+// CHECK-SAME: %[[A:.+]]: tensor<7x8x9xf32>, %[[Out:.+]]: tensor<9x7x8xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %transposed = linalg.transpose ins(%[[A]] : tensor<7x8x9xf32>) outs(%[[Out]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+//
+func.func @transpose3D(%arg0: tensor<7x8x9xf32>, %arg1: tensor<9x7x8xf32>) -> tensor<9x7x8xf32> {
+  %transposed = linalg.transpose ins(%arg0 : tensor<7x8x9xf32>) outs(%arg1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+  return %transposed : tensor<9x7x8xf32>
+}
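
As a quick cross-check of the new 3D case, a sketch applying the dim(result, i) = dim(input, permutation[i]) rule to the shapes above (not part of the test suite):

    #include <array>
    #include <cassert>

    int main() {
      // transpose3D: input tensor<7x8x9xf32>, permutation = [2, 0, 1].
      std::array<int, 3> in = {7, 8, 9}, perm = {2, 0, 1}, out = {};
      for (int i = 0; i < 3; ++i)
        out[i] = in[perm[i]]; // dim(result, i) = dim(input, permutation[i])
      assert((out == std::array<int, 3>{9, 7, 8})); // tensor<9x7x8xf32>
      return 0;
    }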
Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// This test checks that linalg.generic does not get incorrectly specialized to transpose or broadcast.
+// CHECK-LABEL: @transpose_and_broadcast
+// CHECK: linalg.generic
+func.func @transpose_and_broadcast(%arg0: tensor<7x8xf32>, %arg1: tensor<8x7x9xf32>) -> tensor<8x7x9xf32> {
+  %0 = linalg.generic
+      {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<7x8xf32>) outs(%arg1 : tensor<8x7x9xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<8x7x9xf32>
+  return %0 : tensor<8x7x9xf32>
+}
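
This op stays generic under both matchers above: the input map (d0, d1, d2) -> (d1, d0) is not a permutation of all three loop dims, so isaTransposeOpInterface rejects it, and its results are not monotonically increasing DimIds (d1 before d0), so isaBroadcastOpInterface rejects it as well.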
