@@ -52,12 +52,6 @@ using namespace mlir::linalg;
5252#define DBGS () (llvm::dbgs() << ' [' << DEBUG_TYPE << " ] " )
5353#define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
5454
55- // Forward declaration of Conv1DGenerator and its validator
56- namespace {
57- struct Conv1DGenerator ;
58- bool validateConv1DGenerator (RewriterBase &rewriter, LinalgOp linalgOp);
59- } // namespace
60-
6155// / Try to vectorize `convOp` as a convolution.
6256static FailureOr<Operation *>
6357vectorizeConvolution (RewriterBase &rewriter, LinalgOp convOp,
@@ -1945,6 +1939,22 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
19451939 return success ();
19461940}
19471941
1942+ static LogicalResult vectorizeConvOpPrecondition (linalg::LinalgOp convOp) {
1943+ // We only support 1D convolutions, reject all other cases.
1944+ if (isa<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcFhwcOp,
1945+ linalg::Conv2DNchwFchwOp>(convOp)) {
1946+ LDBG (" 2D convolutions are not supported\n " );
1947+ return failure ();
1948+ }
1949+
1950+ if (isa<linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNcdhwFcdhwOp>(convOp)) {
1951+ LDBG (" 3D convolutions are not supported\n " );
1952+ return failure ();
1953+ }
1954+
1955+ return success ();
1956+ }
1957+
19481958static LogicalResult vectorizeLinalgOpPrecondition (
19491959 LinalgOp linalgOp, ArrayRef<int64_t > inputVectorSizes,
19501960 bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
@@ -1996,20 +2006,8 @@ static LogicalResult vectorizeLinalgOpPrecondition(
19962006 // TODO: isaConvolutionOpInterface that can also infer from generic
19972007 // features. But we will still need stride/dilation attributes that will be
19982008 // annoying to reverse-engineer...
1999- if (isa<ConvolutionOpInterface>(linalgOp.getOperation ())) {
2000- // Create a dummy rewriter first, a rewriter is not required for
2001- // validation
2002- IRRewriter dummyBuilder (linalgOp.getContext ());
2003- // Check if we can successfully construct a 1d convolution generator.
2004- // For example, if it is 2d+ convolution, return failure because we don't
2005- // support it. To use this pass on a 2d+ convolution, it should have already
2006- // been decomposed to 1d convolution via
2007- // DecomposeConvolutionToLowerDimOpsPass.
2008- if (!validateConv1DGenerator (dummyBuilder, linalgOp))
2009- return failure ();
2010-
2011- return success ();
2012- }
2009+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation ()))
2010+ return vectorizeConvOpPrecondition (linalgOp);
20132011
20142012 // TODO: the common vector shape is equal to the static loop sizes only when
20152013 // all indexing maps are projected permutations. For convs and stencils the
@@ -3918,34 +3916,28 @@ struct Conv1DGenerator
39183916 }
39193917 }
39203918};
3921-
3922- // Helper function to construct Conv1DGenerator
3923- bool validateConv1DGenerator (RewriterBase &rewriter, LinalgOp linalgOp) {
3924- Conv1DGenerator conv1dGen (rewriter, linalgOp);
3925- return conv1dGen.isValid ();
3926- }
3927-
39283919} // namespace
39293920
39303921// / Helper function to vectorize a LinalgOp with convolution semantics.
39313922// TODO: extend the generic vectorization to support windows and drop this.
39323923static FailureOr<Operation *> vectorizeConvolution (
39333924 RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t > inputVecSizes,
39343925 ArrayRef<bool > inputScalableVecDims, bool flatten1DDepthwiseConv) {
3935- Conv1DGenerator e (rewriter, op);
3936- auto res = e.generateNonChanneledConv ();
3926+ Conv1DGenerator conv1dGen (rewriter, op);
3927+ assert (conv1dGen.isValid () && " Conv1DGenerator failed" );
3928+ auto res = conv1dGen.generateNonChanneledConv ();
39373929 if (succeeded (res))
39383930 return res;
3939- res = e .generateNwcConv ();
3931+ res = conv1dGen .generateNwcConv ();
39403932 if (succeeded (res))
39413933 return res;
3942- res = e .generateNcwConv ();
3934+ res = conv1dGen .generateNcwConv ();
39433935 if (succeeded (res))
39443936 return res;
3945- res = e .generateNwcPooling ();
3937+ res = conv1dGen .generateNwcPooling ();
39463938 if (succeeded (res))
39473939 return res;
3948- res = e .generateNcwPooling ();
3940+ res = conv1dGen .generateNcwPooling ();
39493941 if (succeeded (res))
39503942 return res;
39513943
@@ -3957,11 +3949,9 @@ static FailureOr<Operation *> vectorizeConvolution(
39573949 if (!inputVecSizes.empty ()) {
39583950 // Only use the input vector size corresponding to the channel dim. Other
39593951 // vector dims will be inferred from the Ops.
3960- if (!isa<linalg::DepthwiseConv1DNwcWcOp>(*op) &&
3961- !isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) {
3962- return rewriter.notifyMatchFailure (
3963- op, " Unexpected convolution: expected 1D depthwise conv" );
3964- }
3952+ assert ((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3953+ isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3954+ " Not a 1D depthwise conv!" );
39653955 size_t chDimIdx =
39663956 TypeSwitch<Operation *, size_t >(op)
39673957 .Case <linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2 ; })
@@ -3970,8 +3960,8 @@ static FailureOr<Operation *> vectorizeConvolution(
39703960 vecChDimSize = inputVecSizes[chDimIdx];
39713961 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
39723962 }
3973- return e .generateDilatedConv (vecChDimSize, vecChDimScalableFlag,
3974- flatten1DDepthwiseConv);
3963+ return conv1dGen .generateDilatedConv (vecChDimSize, vecChDimScalableFlag,
3964+ flatten1DDepthwiseConv);
39753965}
39763966
39773967struct VectorizeConvolution : public OpInterfaceRewritePattern <LinalgOp> {
0 commit comments