@@ -1990,8 +1990,18 @@ static LogicalResult vectorizeLinalgOpPrecondition(
19901990 // TODO: isaConvolutionOpInterface that can also infer from generic
19911991 // features. But we will still need stride/dilation attributes that will be
19921992 // annoying to reverse-engineer...
1993- if (isa<ConvolutionOpInterface>(linalgOp.getOperation ()))
1993+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation ())) {
1994+ // Check if it is 2d+ convolution. If it is, return failure because we don't
1995+ // support it. To use this pass on a 2d+ convolution, it should have already
1996+ // been decomposed to 1d convolution via
1997+ // DecomposeConvolutionToLowerDimOpsPass.
1998+ if (linalgOp.getNumParallelLoops () >= 4 ) {
1999+ LDBG (" precondition failed: Regular 2d+ convolutions not supported.\n " );
2000+ return failure ();
2001+ }
19942002 return success ();
2003+ }
2004+
19952005 // TODO: the common vector shape is equal to the static loop sizes only when
19962006 // all indexing maps are projected permutations. For convs and stencils the
19972007 // logic will need to evolve.
@@ -3929,9 +3939,11 @@ static FailureOr<Operation *> vectorizeConvolution(
39293939 if (!inputVecSizes.empty ()) {
39303940 // Only use the input vector size corresponding to the channel dim. Other
39313941 // vector dims will be inferred from the Ops.
3932- assert ((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3933- isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3934- " Not a 1D depthwise conv!" );
3942+ if (!isa<linalg::DepthwiseConv1DNwcWcOp>(*op) &&
3943+ !isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) {
3944+ return rewriter.notifyMatchFailure (
3945+ op, " Unexpected convolution: expected 1D depthwise conv" );
3946+ }
39353947 size_t chDimIdx =
39363948 TypeSwitch<Operation *, size_t >(op)
39373949 .Case <linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2 ; })
0 commit comments