@@ -52,6 +52,12 @@ using namespace mlir::linalg;
5252#define DBGS () (llvm::dbgs() << ' [' << DEBUG_TYPE << " ] " )
5353#define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
5454
55+ // Forward declaration of Conv1DGenerator and its validator
56+ namespace {
57+ struct Conv1DGenerator ;
58+ bool validateConv1DGenerator (RewriterBase &rewriter, LinalgOp linalgOp);
59+ } // namespace
60+
5561// / Try to vectorize `convOp` as a convolution.
5662static FailureOr<Operation *>
5763vectorizeConvolution (RewriterBase &rewriter, LinalgOp convOp,
@@ -1991,14 +1997,17 @@ static LogicalResult vectorizeLinalgOpPrecondition(
19911997 // features. But we will still need stride/dilation attributes that will be
19921998 // annoying to reverse-engineer...
19931999 if (isa<ConvolutionOpInterface>(linalgOp.getOperation ())) {
1994- // Check if it is 2d+ convolution. If it is, return failure because we don't
2000+ // Create a dummy rewriter first, a rewriter is not required for
2001+ // validation
2002+ IRRewriter dummyBuilder (linalgOp.getContext ());
2003+ // Check if we can successfully construct a 1d convolution generator.
2004+ // For example, if it is 2d+ convolution, return failure because we don't
19952005 // support it. To use this pass on a 2d+ convolution, it should have already
19962006 // been decomposed to 1d convolution via
19972007 // DecomposeConvolutionToLowerDimOpsPass.
1998- if (linalgOp.getNumParallelLoops () >= 4 ) {
1999- LDBG (" precondition failed: Regular 2d+ convolutions not supported.\n " );
2008+ if (!validateConv1DGenerator (dummyBuilder, linalgOp))
20002009 return failure ();
2001- }
2010+
20022011 return success ();
20032012 }
20042013
@@ -3197,6 +3206,8 @@ struct Conv1DGenerator
31973206 valid = true ;
31983207 }
31993208
3209+ bool isValid () { return valid; }
3210+
32003211 // / Generate a vector implementation for:
32013212 // / ```
32023213 // / Op def: ( w, kw )
@@ -3907,6 +3918,13 @@ struct Conv1DGenerator
39073918 }
39083919 }
39093920};
3921+
3922+ // Helper function to construct Conv1DGenerator
3923+ bool validateConv1DGenerator (RewriterBase &rewriter, LinalgOp linalgOp) {
3924+ Conv1DGenerator conv1dGen (rewriter, linalgOp);
3925+ return conv1dGen.isValid ();
3926+ }
3927+
39103928} // namespace
39113929
39123930// / Helper function to vectorize a LinalgOp with convolution semantics.
0 commit comments