@@ -75,11 +75,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
7575// of the fused producer & consumer after the fusion can still compute the
7676// bounds of the op.
7777static bool isOpOperandCanBeDroppedAfterFusedLinalgs (
78- GenericOp producer, GenericOp consumer,
78+ LinalgOp producer, LinalgOp consumer,
7979 ArrayRef<OpOperand *> opOperandsToIgnore) {
8080 SmallVector<AffineMap> indexingMaps;
8181
82- SmallVector<GenericOp > ops = {producer, consumer};
82+ SmallVector<LinalgOp > ops = {producer, consumer};
8383 for (auto &op : ops) {
8484 for (auto &opOperand : op->getOpOperands ()) {
8585 if (llvm::is_contained (opOperandsToIgnore, &opOperand)) {
@@ -108,7 +108,7 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
108108// / agree with the result of this method. This function gives a prediction based
109109// / on an optimized fusion.
110110llvm::SmallDenseSet<int > mlir::linalg::getPreservedProducerResults (
111- GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
111+ LinalgOp producer, LinalgOp consumer, OpOperand *fusedOperand) {
112112 llvm::SmallDenseSet<int > preservedProducerResults;
113113 llvm::SmallVector<OpOperand *> opOperandsToIgnore;
114114
@@ -138,8 +138,8 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
138138 if (!fusedOperand)
139139 return false ;
140140
141- auto producer = fusedOperand->get ().getDefiningOp <GenericOp >();
142- auto consumer = dyn_cast<GenericOp >(fusedOperand->getOwner ());
141+ auto producer = fusedOperand->get ().getDefiningOp <LinalgOp >();
142+ auto consumer = dyn_cast<LinalgOp >(fusedOperand->getOwner ());
143143
144144 // Check producer and consumer are generic ops.
145145 if (!producer || !consumer)
@@ -213,16 +213,16 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
213213// / Generate the region of the fused tensor operation. The region of the fused
214214// / op must be empty.
215215static void generateFusedElementwiseOpRegion (
216- RewriterBase &rewriter, GenericOp fusedOp,
216+ RewriterBase &rewriter, LinalgOp fusedOp,
217217 AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
218218 unsigned nloops, llvm::SmallDenseSet<int > &preservedProducerResults) {
219- auto producer = cast<GenericOp >(fusedOperand->get ().getDefiningOp ());
220- auto consumer = cast<GenericOp >(fusedOperand->getOwner ());
219+ auto producer = cast<LinalgOp >(fusedOperand->get ().getDefiningOp ());
220+ auto consumer = cast<LinalgOp >(fusedOperand->getOwner ());
221221 // Build the region of the fused op.
222222 Block &producerBlock = producer->getRegion (0 ).front ();
223223 Block &consumerBlock = consumer->getRegion (0 ).front ();
224224 OpBuilder::InsertionGuard guard (rewriter);
225- Block *fusedBlock = rewriter.createBlock (&fusedOp. getRegion ());
225+ Block *fusedBlock = rewriter.createBlock (&fusedOp-> getRegion (0 ));
226226 IRMapping mapper;
227227
228228 // 2. Add an index operation for every fused loop dimension and use the
@@ -329,7 +329,7 @@ static void generateFusedElementwiseOpRegion(
329329 rewriter.create <YieldOp>(fusedOp.getLoc (), fusedYieldValues);
330330
331331 // Sanity checks.
332- assert (fusedBlock->getNumArguments () == fusedOp. getNumOperands () &&
332+ assert (fusedBlock->getNumArguments () == fusedOp-> getNumOperands () &&
333333 " Ill-formed GenericOp region" );
334334}
335335
@@ -339,8 +339,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
339339 assert (areElementwiseOpsFusable (fusedOperand) &&
340340 " expected elementwise operation pre-conditions to pass" );
341341 auto producerResult = cast<OpResult>(fusedOperand->get ());
342- auto producer = cast<GenericOp >(producerResult.getOwner ());
343- auto consumer = cast<GenericOp >(fusedOperand->getOwner ());
342+ auto producer = cast<LinalgOp >(producerResult.getOwner ());
343+ auto consumer = cast<LinalgOp >(fusedOperand->getOwner ());
344344 // TODO: allow fusing the producer of an output operand.
345345 assert (consumer.isDpsInput (fusedOperand) &&
346346 " expected producer of input operand" );
@@ -415,12 +415,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
415415 }
416416
417417 // Generate the fused op.
418+ // auto fusedOp = cloneWithoutRegions(rewriter, consumer,
419+ // fusedResultTypes, fusedInputOperands);
420+ // fusedOp.setIndexingMapsAttr(idxMap);
421+ // fusedOp.setIteratorTypesAttr(itTp);
418422 auto fusedOp = rewriter.create <GenericOp>(
419423 consumer.getLoc (), fusedResultTypes, fusedInputOperands,
420- fusedOutputOperands, rewriter.getAffineMapArrayAttr (fusedIndexMaps),
421- consumer.getIteratorTypes (),
422- /* doc=*/ nullptr ,
423- /* library_call=*/ nullptr );
424+ fusedOutputOperands, fusedIndexMaps,
425+ consumer.getIteratorTypesArray ());
424426 if (!fusedOp.getShapesToLoopsMap ()) {
425427 // Fused op has invalid indexing maps. Typically this means something is off
426428 // in the input, but going ahead here would result in verification errors.
@@ -459,14 +461,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
459461
460462namespace {
461463// / Patterns to fuse a generic op, with the producer of its operands.
462- class FuseElementwiseOps : public OpRewritePattern <GenericOp > {
464+ class FuseElementwiseOps : public OpInterfaceRewritePattern <LinalgOp > {
463465public:
464466 FuseElementwiseOps (MLIRContext *context, ControlFusionFn fun,
465467 PatternBenefit benefit = 1 )
466- : OpRewritePattern<GenericOp >(context, benefit),
468+ : OpInterfaceRewritePattern<LinalgOp >(context, benefit),
467469 controlFn (std::move(fun)) {}
468470
469- LogicalResult matchAndRewrite (GenericOp genericOp,
471+ LogicalResult matchAndRewrite (LinalgOp genericOp,
470472 PatternRewriter &rewriter) const override {
471473 // Find the first operand that is defined by another generic op on tensors.
472474 for (OpOperand &opOperand : genericOp->getOpOperands ()) {
@@ -493,7 +495,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
493495 rewriter.eraseOp (genericOp);
494496 return success ();
495497 }
496- return failure ( );
498+ return rewriter. notifyMatchFailure (genericOp, " no fusable operands " );
497499 }
498500
499501private:
0 commit comments