@@ -3135,10 +3135,8 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
31353135// / kw is unrolled, w is unrolled iff dilationW > 1.
31363136struct Conv1DGenerator
31373137 : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3138- Conv1DGenerator (RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
3139- int dilationW)
3140- : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
3141- strideW (strideW), dilationW(dilationW) {
3138+ Conv1DGenerator (RewriterBase &rewriter, LinalgOp linalgOp)
3139+ : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
31423140 // Determine whether `linalgOp` can be generated with this generator
31433141 if (linalgOp.getNumDpsInputs () != 2 || linalgOp.getNumDpsInits () != 1 )
31443142 return ;
@@ -3185,6 +3183,16 @@ struct Conv1DGenerator
31853183 return ;
31863184 break ;
31873185 }
3186+
3187+ // The ConvolutionOpInterface gives us guarantees of existence for
3188+ // strides/dilations. However, we do not need to rely on those, we can
3189+ // simply use them if present, otherwise use the default and let the generic
3190+ // conv. matcher in the ConvGenerator succeed or fail.
3191+ auto strides = linalgOp->getAttrOfType <DenseIntElementsAttr>(" strides" );
3192+ auto dilations = linalgOp->getAttrOfType <DenseIntElementsAttr>(" dilations" );
3193+ strideW = strides ? *strides.getValues <uint64_t >().begin () : 1 ;
3194+ dilationW = dilations ? *dilations.getValues <uint64_t >().begin () : 1 ;
3195+
31883196 // The op is now known to be valid.
31893197 valid = true ;
31903198 }
@@ -3906,15 +3914,7 @@ struct Conv1DGenerator
39063914static FailureOr<Operation *> vectorizeConvolution (
39073915 RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t > inputVecSizes,
39083916 ArrayRef<bool > inputScalableVecDims, bool flatten1DDepthwiseConv) {
3909- // The ConvolutionOpInterface gives us guarantees of existence for
3910- // strides/dilations. However, we do not need to rely on those, we can
3911- // simply use them if present, otherwise use the default and let the generic
3912- // conv. matcher in the ConvGenerator succeed or fail.
3913- auto strides = op->getAttrOfType <DenseIntElementsAttr>(" strides" );
3914- auto dilations = op->getAttrOfType <DenseIntElementsAttr>(" dilations" );
3915- auto stride = strides ? *strides.getValues <uint64_t >().begin () : 1 ;
3916- auto dilation = dilations ? *dilations.getValues <uint64_t >().begin () : 1 ;
3917- Conv1DGenerator e (rewriter, op, stride, dilation);
3917+ Conv1DGenerator e (rewriter, op);
39183918 auto res = e.generateNonChanneledConv ();
39193919 if (succeeded (res))
39203920 return res;
0 commit comments