Skip to content

Commit 80474f7

Browse files
committed
Refactor Conv1DGenerator
1 parent a33b211 commit 80474f7

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3135,10 +3135,8 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
31353135
/// kw is unrolled, w is unrolled iff dilationW > 1.
31363136
struct Conv1DGenerator
31373137
: public StructuredGenerator<LinalgOp, utils::IteratorType> {
3138-
Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
3139-
int dilationW)
3140-
: StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
3141-
strideW(strideW), dilationW(dilationW) {
3138+
Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3139+
: StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
31423140
// Determine whether `linalgOp` can be generated with this generator
31433141
if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
31443142
return;
@@ -3185,6 +3183,16 @@ struct Conv1DGenerator
31853183
return;
31863184
break;
31873185
}
3186+
3187+
// The ConvolutionOpInterface gives us guarantees of existence for
3188+
// strides/dilations. However, we do not need to rely on those, we can
3189+
// simply use them if present, otherwise use the default and let the generic
3190+
// conv. matcher in the ConvGenerator succeed or fail.
3191+
auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
3192+
auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
3193+
strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3194+
dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3195+
31883196
// The op is now known to be valid.
31893197
valid = true;
31903198
}
@@ -3906,15 +3914,7 @@ struct Conv1DGenerator
39063914
static FailureOr<Operation *> vectorizeConvolution(
39073915
RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
39083916
ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3909-
// The ConvolutionOpInterface gives us guarantees of existence for
3910-
// strides/dilations. However, we do not need to rely on those, we can
3911-
// simply use them if present, otherwise use the default and let the generic
3912-
// conv. matcher in the ConvGenerator succeed or fail.
3913-
auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3914-
auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3915-
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3916-
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3917-
Conv1DGenerator e(rewriter, op, stride, dilation);
3917+
Conv1DGenerator e(rewriter, op);
39183918
auto res = e.generateNonChanneledConv();
39193919
if (succeeded(res))
39203920
return res;

0 commit comments

Comments
 (0)