1111#include " llvm/ADT/SetOperations.h"
1212#include " llvm/ADT/SmallBitVector.h"
1313#include " llvm/ADT/SmallVector.h"
14+ #include " llvm/Support/DebugLog.h"
1415#include " mlir/Dialect/Affine/IR/AffineOps.h"
1516#include " mlir/Dialect/Arith/IR/Arith.h"
1617#include " mlir/Dialect/Complex/IR/Complex.h"
2021#include " mlir/IR/MLIRContext.h"
2122#include " mlir/IR/TypeUtilities.h"
2223
24+ #define DEBUG_TYPE " iree-linalgext-utils-matchutils"
25+
2326// This is based on the upstream implementation of the
2427// Linalg::ContractionOpInterface.
2528
@@ -121,17 +124,17 @@ inferIteratorsFromOutMap(AffineMap map) {
121124
122125bool isScaledContractionBody (Block &block) {
123126 if (block.empty () || !block.back ().mightHaveTrait <OpTrait::IsTerminator>()) {
124- llvm::errs () << " no terminator in the block" ;
127+ LDBG () << " no terminator in the block" ;
125128 return false ;
126129 }
127130 if (block.getNumArguments () != 5 ) {
128- llvm::errs () << " expected block with 3 arguments" ;
131+ LDBG () << " expected block with 3 arguments" ;
129132 return false ;
130133 }
131134
132135 Operation *terminator = block.getTerminator ();
133136 if (terminator->getNumOperands () != 1 ) {
134- llvm::errs () << " expected terminator with 1 operand" ;
137+ LDBG () << " expected terminator with 1 operand" ;
135138 return false ;
136139 }
137140
@@ -156,7 +159,7 @@ bool isScaledContractionBody(Block &block) {
156159 Value yielded = getSourceSkipUnary (terminator->getOperand (0 ));
157160 Operation *reductionOp = yielded.getDefiningOp ();
158161 if (reductionOp->getNumResults () != 1 || reductionOp->getNumOperands () != 2 ) {
159- llvm::errs () << " expected reduction op to be binary" ;
162+ LDBG () << " expected reduction op to be binary" ;
160163 return false ;
161164 }
162165
@@ -165,9 +168,8 @@ bool isScaledContractionBody(Block &block) {
165168
166169 if (reductionLHS != block.getArgument (4 ) &&
167170 reductionRHS != block.getArgument (4 )) {
168- llvm::errs ()
169- << " expected reduction to take block argument #4 as one of the "
170- " operands (modulo unary casts)" ;
171+ LDBG () << " expected reduction to take block argument #4 as one of the "
172+ " operands (modulo unary casts)" ;
171173 return false ;
172174 }
173175
@@ -176,11 +178,11 @@ bool isScaledContractionBody(Block &block) {
176178 Operation *elementwiseOp = contributed.getDefiningOp ();
177179 if (!elementwiseOp || elementwiseOp->getNumResults () != 1 ||
178180 elementwiseOp->getNumOperands () != 2 ) {
179- llvm::errs () << " expected elementwise op to be binary" ;
181+ LDBG () << " expected elementwise op to be binary" ;
180182 return false ;
181183 }
182184 if (!isValidScaledMmaPair (reductionOp, elementwiseOp)) {
183- llvm::errs () << " expected reduction/elementwise op kind not satisfied" ;
185+ LDBG () << " expected reduction/elementwise op kind not satisfied" ;
184186 return false ;
185187 }
186188
@@ -193,9 +195,8 @@ bool isScaledContractionBody(Block &block) {
193195 return true ;
194196 }
195197
196- llvm::errs ()
197- << " expected elementwise op to apply to block arguments (modulo unary "
198- " casts)" ;
198+ LDBG () << " expected elementwise op to apply to block arguments (modulo unary "
199+ " casts)" ;
199200 return false ;
200201}
201202
0 commit comments