@@ -1087,43 +1087,44 @@ LogicalResult GenericOp::verify() { return success(); }
10871087
10881088namespace {
10891089
1090- // / Remove generic operations (on tensors) that are just copying
1090+ // / Remove any linalg operation (on tensors) that are just copying
10911091// / the values from inputs to the results. Requirements are
10921092// / 1) All iterator types are parallel
10931093// / 2) The body contains just a yield operation with the yielded values being
10941094// / the arguments corresponding to the operands.
1095- struct EraseIdentityGenericOp : public OpRewritePattern <GenericOp> {
1096- using OpRewritePattern<GenericOp>::OpRewritePattern;
1095+ template <typename OpTy>
1096+ struct EraseIdentityLinalgOp : public OpRewritePattern <OpTy> {
1097+ using OpRewritePattern<OpTy>::OpRewritePattern;
10971098
1098- LogicalResult matchAndRewrite (GenericOp genericOp ,
1099+ LogicalResult matchAndRewrite (OpTy linalgOp ,
10991100 PatternRewriter &rewriter) const override {
11001101 // Check all indexing maps are identity.
1101- if (llvm::any_of (genericOp .getIndexingMapsArray (),
1102+ if (llvm::any_of (linalgOp .getIndexingMapsArray (),
11021103 [](AffineMap map) { return !map.isIdentity (); }))
11031104 return failure ();
11041105
11051106 // Check that the body of the linalg operation is just a linalg.yield
11061107 // operation.
1107- Block &body = genericOp. getRegion ().front ();
1108+ Block &body = linalgOp-> getRegion (0 ).front ();
11081109 if (!llvm::hasSingleElement (body))
11091110 return failure ();
11101111 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator ());
11111112 if (!yieldOp)
11121113 return failure ();
11131114
11141115 // In the buffer case, we need to check exact buffer equality.
1115- if (genericOp .hasPureBufferSemantics ()) {
1116- if (genericOp .getNumDpsInputs () == 1 && genericOp .getNumDpsInits () == 1 &&
1117- genericOp .getDpsInputOperand (0 )->get () ==
1118- genericOp .getDpsInitOperand (0 )->get ()) {
1119- rewriter.eraseOp (genericOp );
1116+ if (linalgOp .hasPureBufferSemantics ()) {
1117+ if (linalgOp .getNumDpsInputs () == 1 && linalgOp .getNumDpsInits () == 1 &&
1118+ linalgOp .getDpsInputOperand (0 )->get () ==
1119+ linalgOp .getDpsInitOperand (0 )->get ()) {
1120+ rewriter.eraseOp (linalgOp );
11201121 return success ();
11211122 }
11221123 return failure ();
11231124 }
11241125
11251126 // Mixed semantics is not supported yet.
1126- if (!genericOp .hasPureTensorSemantics ())
1127+ if (!linalgOp .hasPureTensorSemantics ())
11271128 return failure ();
11281129
11291130 // Get the argument number of the returned values. That is the operand
@@ -1134,8 +1135,8 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
11341135 if (!yieldArg || yieldArg.getOwner () != &body)
11351136 return failure ();
11361137 unsigned argumentNumber = yieldArg.getArgNumber ();
1137- Value returnedArg = genericOp ->getOperand (argumentNumber);
1138- Type resultType = genericOp ->getResult (yieldVal.index ()).getType ();
1138+ Value returnedArg = linalgOp ->getOperand (argumentNumber);
1139+ Type resultType = linalgOp ->getResult (yieldVal.index ()).getType ();
11391140 // The input can have a different type than the result, e.g. a dynamic
11401141 // input dimension can be turned into a static output dimension.
11411142 Type returnType = returnedArg.getType ();
@@ -1145,21 +1146,21 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
11451146 if (sparse_tensor::getSparseTensorEncoding (returnType) ||
11461147 sparse_tensor::getSparseTensorEncoding (resultType))
11471148 returnedArg = rewriter.create <sparse_tensor::ConvertOp>(
1148- genericOp .getLoc (), resultType, returnedArg);
1149+ linalgOp .getLoc (), resultType, returnedArg);
11491150 else {
11501151 if (!tensor::CastOp::areCastCompatible (returnedArg.getType (),
11511152 resultType))
11521153 return failure ();
11531154 returnedArg = rewriter.create <tensor::CastOp>(
1154- genericOp .getLoc (), resultType, returnedArg);
1155+ linalgOp .getLoc (), resultType, returnedArg);
11551156 }
11561157 }
11571158 returnedArgs.push_back (returnedArg);
11581159 }
11591160
1160- if (returnedArgs.size () != genericOp ->getNumResults ())
1161+ if (returnedArgs.size () != linalgOp ->getNumResults ())
11611162 return failure ();
1162- rewriter.replaceOp (genericOp , returnedArgs);
1163+ rewriter.replaceOp (linalgOp , returnedArgs);
11631164 return success ();
11641165 }
11651166};
@@ -1168,7 +1169,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
11681169
11691170void GenericOp::getCanonicalizationPatterns (RewritePatternSet &results,
11701171 MLIRContext *context) {
1171- results.add <EraseIdentityGenericOp >(context);
1172+ results.add <EraseIdentityLinalgOp<GenericOp> >(context);
11721173}
11731174
11741175LogicalResult GenericOp::fold (FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
@@ -1907,6 +1908,11 @@ void BroadcastOp::getEffects(
19071908 getDpsInits ());
19081909}
19091910
1911+ void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &results,
1912+ MLIRContext *context) {
1913+ results.add <EraseIdentityLinalgOp<BroadcastOp>>(context);
1914+ }
1915+
19101916// ===----------------------------------------------------------------------===//
19111917// YieldOp
19121918// ===----------------------------------------------------------------------===//
0 commit comments