Skip to content

Commit ce33596

Browse files
committed
remove hack for linalg.map and update tests
1 parent a12b417 commit ce33596

File tree

2 files changed

+2
-30
lines changed

2 files changed

+2
-30
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -224,28 +224,8 @@ static void generateFusedElementwiseOpRegion(
224224
auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
225225
// Build the region of the fused op.
226226

227-
// Since some ops, like `linalg.map`, do not have block arguments for init
228-
// operands then we first "generalize" the block by adding arguments for init
229-
// operands when they aren't present. We detect this case by checking if
230-
// `getOpOperandsMatchingBBargs() == getDpsInputOperands()`.
231-
// TODO: This is hacky and should not be merged. Keeping for now for testing
232-
// purposes in the meantime, but need a better way
233227
Block &producerBlock = producer->getRegion(0).front();
234-
bool addOutputArgsProducer =
235-
producer.getOpOperandsMatchingBBargs() == producer.getDpsInputOperands();
236-
if (addOutputArgsProducer) {
237-
for (auto init : producer.getDpsInits())
238-
producerBlock.addArgument(getElementTypeOrSelf(init.getType()),
239-
producer.getLoc());
240-
}
241228
Block &consumerBlock = consumer->getRegion(0).front();
242-
bool addOutputArgsConsumer =
243-
consumer.getOpOperandsMatchingBBargs() == consumer.getDpsInputOperands();
244-
if (addOutputArgsConsumer) {
245-
for (auto init : consumer.getDpsInits())
246-
consumerBlock.addArgument(getElementTypeOrSelf(init.getType()),
247-
consumer.getLoc());
248-
}
249229
OpBuilder::InsertionGuard guard(rewriter);
250230
Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
251231
IRMapping mapper;
@@ -355,14 +335,6 @@ static void generateFusedElementwiseOpRegion(
355335
// Sanity checks.
356336
assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
357337
"Ill-formed GenericOp region");
358-
// Erase added args in case that the ops are still live after fusion.
359-
// TODO: Remove along with hacky code above.
360-
if (addOutputArgsProducer)
361-
producerBlock.eraseArguments(producer.getNumDpsInputs(),
362-
producer.getNumDpsInits());
363-
if (addOutputArgsConsumer)
364-
consumerBlock.eraseArguments(consumer.getNumDpsInputs(),
365-
consumer.getNumDpsInits());
366338
}
367339

368340
FailureOr<mlir::linalg::ElementwiseOpFusionResult>

mlir/test/Dialect/Linalg/fusion-elementwise.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> te
8989
%0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
9090
%1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
9191
%2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
92-
(%in0 : f32, %in1 : f32) {
92+
(%in0 : f32, %in1 : f32, %out : f32) {
9393
%cmp = arith.cmpf olt, %in0, %in1 : f32
9494
linalg.yield %cmp : i1
9595
}
@@ -150,7 +150,7 @@ func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tenso
150150
func.func @map_multi_ops(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
151151
%fill = tensor.empty() : tensor<8xf32>
152152
%add_exp = linalg.map ins(%arg0, %arg1: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
153-
(%in0 : f32, %in1 : f32) {
153+
(%in0 : f32, %in1 : f32, %out : f32) {
154154
%add = arith.addf %in0, %in1 : f32
155155
%exp = math.exp %add : f32
156156
linalg.yield %exp : f32

0 commit comments

Comments
 (0)