@@ -224,28 +224,8 @@ static void generateFusedElementwiseOpRegion(
224224 auto consumer = cast<LinalgOp>(fusedOperand->getOwner ());
225225 // Build the region of the fused op.
226226
227- // Since some ops, like `linalg.map`, do not have block arguments for init
228- // operands then we first "generalize" the block by adding arguments for init
229- // operands when they aren't present. We detect this case by checking if
230- // `getOpOperandsMatchingBBargs() == getDpsInputOperands()`.
231- // TODO: This is hacky and should not be merged. Keeping for now for testing
232- // purposes in the meantime, but need a better way
233227 Block &producerBlock = producer->getRegion (0 ).front ();
234- bool addOutputArgsProducer =
235- producer.getOpOperandsMatchingBBargs () == producer.getDpsInputOperands ();
236- if (addOutputArgsProducer) {
237- for (auto init : producer.getDpsInits ())
238- producerBlock.addArgument (getElementTypeOrSelf (init.getType ()),
239- producer.getLoc ());
240- }
241228 Block &consumerBlock = consumer->getRegion (0 ).front ();
242- bool addOutputArgsConsumer =
243- consumer.getOpOperandsMatchingBBargs () == consumer.getDpsInputOperands ();
244- if (addOutputArgsConsumer) {
245- for (auto init : consumer.getDpsInits ())
246- consumerBlock.addArgument (getElementTypeOrSelf (init.getType ()),
247- consumer.getLoc ());
248- }
249229 OpBuilder::InsertionGuard guard (rewriter);
250230 Block *fusedBlock = rewriter.createBlock (&fusedOp->getRegion (0 ));
251231 IRMapping mapper;
@@ -355,14 +335,6 @@ static void generateFusedElementwiseOpRegion(
355335 // Sanity checks.
356336 assert (fusedBlock->getNumArguments () == fusedOp->getNumOperands () &&
357337 " Ill-formed GenericOp region" );
358- // Erase added args in case that the ops are still live after fusion.
359- // TODO: Remove along with hacky code above.
360- if (addOutputArgsProducer)
361- producerBlock.eraseArguments (producer.getNumDpsInputs (),
362- producer.getNumDpsInits ());
363- if (addOutputArgsConsumer)
364- consumerBlock.eraseArguments (consumer.getNumDpsInputs (),
365- consumer.getNumDpsInits ());
366338}
367339
368340FailureOr<mlir::linalg::ElementwiseOpFusionResult>
0 commit comments