Skip to content

Commit 7d402c1

Browse files
committed
remove block args that were added (hacky)
1 parent cf67ab6 commit 7d402c1

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,17 +226,21 @@ static void generateFusedElementwiseOpRegion(
226226
// Since some ops, like `linalg.map`, do not have block arguments for init
227227
// operands then we first "generalize" the block by adding arguments for init
228228
// operands when they aren't present. We detect this case by checking if
229-
// `getOpOperandsMatchingBBargs() == getDpsInputOperands()
229+
// `getOpOperandsMatchingBBargs() == getDpsInputOperands()`.
230+
// TODO: This is hacky and should not be merged. Keeping for now for testing
231+
// purposes in the meantime, but need a better way
230232
Block &producerBlock = producer->getRegion(0).front();
231-
if (producer.getOpOperandsMatchingBBargs() ==
232-
producer.getDpsInputOperands()) {
233+
bool addOutputArgsProducer =
234+
producer.getOpOperandsMatchingBBargs() == producer.getDpsInputOperands();
235+
if (addOutputArgsProducer) {
233236
for (auto init : producer.getDpsInits())
234237
producerBlock.addArgument(getElementTypeOrSelf(init.getType()),
235238
producer.getLoc());
236239
}
237240
Block &consumerBlock = consumer->getRegion(0).front();
238-
if (consumer.getOpOperandsMatchingBBargs() ==
239-
consumer.getDpsInputOperands()) {
241+
bool addOutputArgsConsumer =
242+
consumer.getOpOperandsMatchingBBargs() == consumer.getDpsInputOperands();
243+
if (addOutputArgsConsumer) {
240244
for (auto init : consumer.getDpsInits())
241245
consumerBlock.addArgument(getElementTypeOrSelf(init.getType()),
242246
consumer.getLoc());
@@ -350,6 +354,14 @@ static void generateFusedElementwiseOpRegion(
350354
// Sanity checks.
351355
assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
352356
"Ill-formed GenericOp region");
357+
// Erase added args in case that the ops are still live after fusion.
358+
// TODO: Remove along with hacky code above.
359+
if (addOutputArgsProducer)
360+
producerBlock.eraseArguments(producer.getNumDpsInputs(),
361+
producer.getNumDpsInits());
362+
if (addOutputArgsConsumer)
363+
consumerBlock.eraseArguments(consumer.getNumDpsInputs(),
364+
consumer.getNumDpsInits());
353365
}
354366

355367
FailureOr<mlir::linalg::ElementwiseOpFusionResult>

0 commit comments

Comments
 (0)