1717#include " mlir/Dialect/Linalg/Transforms/Transforms.h"
1818#include " mlir/Pass/Pass.h"
1919
20+ #define DEBUG_TYPE " iree-global-opt-generalize-linalg-named-ops"
21+ #define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE " ]: " )
22+ #define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
23+
2024namespace mlir ::iree_compiler::GlobalOptimization {
2125
2226#define GEN_PASS_DEF_GENERALIZELINALGNAMEDOPSPASS
@@ -41,17 +45,15 @@ static bool isConvFoldableToContraction(linalg::LinalgOp linalgOp) {
4145
4246 if (!llvm::all_of (convDims.strides ,
4347 [](int64_t element) { return element == 1 ; })) {
48+ LDBG (" conv not foldable: non-unit strides" );
4449 return false ;
4550 }
4651
47- // Dont generalize depthwise convolutions.
48- if (!convDims.depth .empty ()) {
49- return false ;
50- }
51-
52- // Dont generalize pooling operations. For pooling ops, the input/output
53- // channel size will be categorized as the additional batch dimension
52+ // Don't generalize pooling operations or depthwise convolutions. For pooling
53+ // ops, the input/output channel size will be categorized as the additional
54+ // batch dimension.
5455 if (convDims.outputChannel .empty () || convDims.inputChannel .empty ()) {
56+ LDBG (" conv not foldable: missing input or output channel dims" );
5557 return false ;
5658 }
5759
@@ -60,6 +62,7 @@ static bool isConvFoldableToContraction(linalg::LinalgOp linalgOp) {
6062 auto filterShapeType = llvm::dyn_cast<RankedTensorType>(
6163 linalgOp.getDpsInputOperand (kFilterInputIdx )->get ().getType ());
6264 if (!filterShapeType) {
65+ LDBG (" conv not foldable: filter shape not ranked tensor" );
6366 return false ;
6467 }
6568 auto filterShape = filterShapeType.getShape ();
@@ -68,6 +71,7 @@ static bool isConvFoldableToContraction(linalg::LinalgOp linalgOp) {
6871 std::optional<int64_t > maybeDim = filterMap.getResultPosition (
6972 getAffineDimExpr (filterLoop, filterMap.getContext ()));
7073 if (!maybeDim || filterShape[*maybeDim] != 1 ) {
74+ LDBG (" conv not foldable: non-unit filter dim" );
7175 return false ;
7276 }
7377 }
0 commit comments