@@ -226,17 +226,21 @@ static void generateFusedElementwiseOpRegion(
226226 // Since some ops, like `linalg.map`, do not have block arguments for init
227227 // operands then we first "generalize" the block by adding arguments for init
228228 // operands when they aren't present. We detect this case by checking if
229- // `getOpOperandsMatchingBBargs() == getDpsInputOperands()
229+ // `getOpOperandsMatchingBBargs() == getDpsInputOperands()`.
230+ // TODO: This is a hack and should not be merged as-is. It is kept for now
231+ // for testing purposes, but a better way to detect this case is needed.
230232 Block &producerBlock = producer->getRegion (0 ).front ();
231- if (producer.getOpOperandsMatchingBBargs () ==
232- producer.getDpsInputOperands ()) {
233+ bool addOutputArgsProducer =
234+ producer.getOpOperandsMatchingBBargs () == producer.getDpsInputOperands ();
235+ if (addOutputArgsProducer) {
233236 for (auto init : producer.getDpsInits ())
234237 producerBlock.addArgument (getElementTypeOrSelf (init.getType ()),
235238 producer.getLoc ());
236239 }
237240 Block &consumerBlock = consumer->getRegion (0 ).front ();
238- if (consumer.getOpOperandsMatchingBBargs () ==
239- consumer.getDpsInputOperands ()) {
241+ bool addOutputArgsConsumer =
242+ consumer.getOpOperandsMatchingBBargs () == consumer.getDpsInputOperands ();
243+ if (addOutputArgsConsumer) {
240244 for (auto init : consumer.getDpsInits ())
241245 consumerBlock.addArgument (getElementTypeOrSelf (init.getType ()),
242246 consumer.getLoc ());
@@ -350,6 +354,14 @@ static void generateFusedElementwiseOpRegion(
350354 // Sanity checks.
351355 assert (fusedBlock->getNumArguments () == fusedOp->getNumOperands () &&
352356 " Ill-formed GenericOp region" );
357+ // Erase the block arguments added above in case the ops are still live after fusion.
358+ // TODO: Remove along with hacky code above.
359+ if (addOutputArgsProducer)
360+ producerBlock.eraseArguments (producer.getNumDpsInputs (),
361+ producer.getNumDpsInits ());
362+ if (addOutputArgsConsumer)
363+ consumerBlock.eraseArguments (consumer.getNumDpsInputs (),
364+ consumer.getNumDpsInits ());
353365}
354366
355367FailureOr<mlir::linalg::ElementwiseOpFusionResult>
0 commit comments