@@ -1939,19 +1939,124 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+namespace {
+bool isCastOfBlockArgument(Operation *op) {
+  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
+         isa<BlockArgument>(op->getOperand(0));
+}
+
+// Returns true iff `reduceOp` is the reduction of a valid conv/pooling op,
+// setting `oper` accordingly.
+// If the region has 2 ops (reduction + yield), or 3 ops (extension +
+// reduction + yield) and the rhs is not used, then it is the body of a
+// pooling op.
+// If it is a conv, check for a single `mul` predecessor. The `mul` operands
+// must be block arguments or extensions of block arguments.
+// Otherwise, check for one or zero `ext` predecessors. The `ext` operands
+// must be block arguments or extensions of block arguments.
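+// For illustration only (hypothetical f32 IR, not taken from this patch): a
+// convolution body reduces a product of the two block arguments,
+//   %0 = arith.mulf %in, %filter : f32
+//   %1 = arith.addf %out, %0 : f32
+//   linalg.yield %1 : f32
+// whereas a max-pooling body combines the input directly into the
+// accumulator,
+//   %0 = arith.maximumf %out, %in : f32
+//   linalg.yield %0 : f32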
+enum OperKind { Conv, Pool };
+bool getOperKind(Operation *reduceOp, OperKind &oper) {
+  int numBlockArguments =
+      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
+
+  switch (numBlockArguments) {
+  case 1: {
+    // Will be convolution if the feeder is a MulOp.
+    // A strength-reduced version of MulOp for the i1 type is AndOp, which is
+    // also supported. Otherwise, it can be pooling. This strength-reduction
+    // logic is in the `buildBinaryFn` helper in the Linalg dialect.
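+    // E.g., for i1 values x * y == x & y, so `arith.muli %a, %b : i1` is
+    // built as `arith.andi %a, %b : i1`.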
+    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
+                                       llvm::IsaPred<BlockArgument>);
+    Operation *feedOp = (*feedValIt).getDefiningOp();
+    if (isCastOfBlockArgument(feedOp)) {
+      oper = Pool;
+    } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+                  (isa<arith::AndIOp>(feedOp) &&
+                   feedOp->getResultTypes()[0].isInteger(1))) &&
+                 llvm::all_of(feedOp->getOperands(), [](Value v) {
+                   if (isa<BlockArgument>(v))
+                     return true;
+                   if (Operation *op = v.getDefiningOp())
+                     return isCastOfBlockArgument(op);
+                   return false;
+                 }))) {
+      return false;
+    }
+    return true;
+  }
+  case 2:
+    // Must be pooling.
+    oper = Pool;
+    return true;
+  default:
+    return false;
+  }
+}
+
+bool isSupportedPoolKind(vector::CombiningKind kind) {
+  switch (kind) {
+  case vector::CombiningKind::ADD:
+  case vector::CombiningKind::MAXNUMF:
+  case vector::CombiningKind::MAXIMUMF:
+  case vector::CombiningKind::MAXSI:
+  case vector::CombiningKind::MAXUI:
+  case vector::CombiningKind::MINNUMF:
+  case vector::CombiningKind::MINIMUMF:
+  case vector::CombiningKind::MINSI:
+  case vector::CombiningKind::MINUI:
+    return true;
+  default:
+    return false;
+  }
+}
+} // namespace
+
 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
-  // We only support 1D convolutions, reject all other cases.
-  if (isa<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcFhwcOp,
-          linalg::Conv2DNchwFchwOp>(convOp)) {
-    LDBG("2D convolutions are not supported\n");
+  if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
+    return failure();
+
+  auto lhsShaped = convOp.getDpsInputOperand(0)->get();
+  auto rhsShaped = convOp.getDpsInputOperand(1)->get();
+  auto resShaped = convOp.getDpsInitOperand(0)->get();
+  auto lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
+  auto rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
+  auto resShapedType = dyn_cast<ShapedType>(resShaped.getType());
+  if (!lhsShapedType || !rhsShapedType || !resShapedType)
+    return failure();
+  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
+  // (non-channeled convolution -> LHS and RHS both have single dimensions).
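+  // E.g. (illustration, not exhaustive): linalg.conv_1d_nwc_wcf has a rank-3
+  // LHS (N, W, C) and a rank-3 result (N, W, F), while linalg.conv_1d is the
+  // non-channeled case with rank-1 operands.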
+  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
+      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
     return failure();
-  }
 
-  if (isa<linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNcdhwFcdhwOp>(convOp)) {
-    LDBG("3D convolutions are not supported\n");
+  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
+  if (!reduceOp)
+    return failure();
+
+  OperKind oper = Conv;
+  if (!getOperKind(reduceOp, oper))
+    return failure();
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  // Typically a convolution will have an `Add` CombiningKind, but for the i1
+  // type it can get strength reduced to `OR`, which is also supported. This
+  // strength-reduction logic is in the `buildBinaryFn` helper in the Linalg
+  // dialect.
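+  // E.g. (hypothetical i1 accumulation): `%0 = arith.ori %out, %prod : i1`
+  // would be seen here as CombiningKind::OR rather than CombiningKind::ADD.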
+  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+                      *maybeKind != vector::CombiningKind::OR) &&
+                     (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
     return failure();
   }
 
+  auto rhsRank = rhsShapedType.getRank();
+  switch (oper) {
+  case Conv:
+    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
+      return failure();
+    break;
+  case Pool:
+    if (rhsRank != 1)
+      return failure();
+    break;
+  }
+
   return success();
 }
 
@@ -3084,28 +3189,6 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
 }
 
 namespace {
-bool isCastOfBlockArgument(Operation *op) {
-  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
-         isa<BlockArgument>(op->getOperand(0));
-}
-
-bool isSupportedPoolKind(vector::CombiningKind kind) {
-  switch (kind) {
-  case vector::CombiningKind::ADD:
-  case vector::CombiningKind::MAXNUMF:
-  case vector::CombiningKind::MAXIMUMF:
-  case vector::CombiningKind::MAXSI:
-  case vector::CombiningKind::MAXUI:
-  case vector::CombiningKind::MINNUMF:
-  case vector::CombiningKind::MINIMUMF:
-  case vector::CombiningKind::MINSI:
-  case vector::CombiningKind::MINUI:
-    return true;
-  default:
-    return false;
-  }
-}
-
 /// Generate a vector implementation for either:
 /// ```
 /// Op def: ( w, kw )
@@ -3144,53 +3227,22 @@ struct Conv1DGenerator
     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
   Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
       : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
-    // Determine whether `linalgOp` can be generated with this generator
-    if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
-      return;
+
     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
     resShaped = linalgOp.getDpsInitOperand(0)->get();
     lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
     rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
     resShapedType = dyn_cast<ShapedType>(resShaped.getType());
-    if (!lhsShapedType || !rhsShapedType || !resShapedType)
-      return;
-    // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
-    // (non-channeled convolution -> LHS and RHS both have single dimensions).
-    if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
-        (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
-      return;
 
     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
-    if (!reduceOp)
-      return;
     redOp = reduceOp->getName().getIdentifier();
 
-    if (!setOperKind(reduceOp))
-      return;
+    setOperKind(reduceOp);
+
     auto maybeKind = getCombinerOpKind(reduceOp);
-    // Typically convolution will have a `Add` CombiningKind but for i1 type it
-    // can get strength reduced to `OR` which is also supported. This strength
-    // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
-    if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
-                        *maybeKind != vector::CombiningKind::OR) &&
-                       (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
-      return;
-    }
     reductionKind = maybeKind.value();
 
-    auto rhsRank = rhsShapedType.getRank();
-    switch (oper) {
-    case Conv:
-      if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
-        return;
-      break;
-    case Pool:
-      if (rhsRank != 1)
-        return;
-      break;
-    }
-
     // The ConvolutionOpInterface gives us guarantees of existence for
     // strides/dilations. However, we do not need to rely on those, we can
     // simply use them if present, otherwise use the default and let the generic
@@ -3199,13 +3251,8 @@ struct Conv1DGenerator
     auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
     strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
     dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-
-    // The op is now known to be valid.
-    valid = true;
   }
 
-  bool isValid() { return valid; }
-
   /// Generate a vector implementation for:
   /// ```
   /// Op def: ( w, kw )
@@ -3225,9 +3272,6 @@ struct Conv1DGenerator
   /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
   /// > 1.
   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
-    if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
-
     int64_t nSize, wSize, cSize, kwSize, fSize;
     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
     bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
@@ -3510,9 +3554,6 @@ struct Conv1DGenerator
   FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                        bool channelDimScalableFlag,
                                        bool flatten) {
-    if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
-
     bool scalableChDim = false;
     bool useMasking = false;
     int64_t nSize, wSize, cSize, kwSize;
@@ -3857,8 +3898,6 @@ struct Conv1DGenerator
   }
 
 private:
-  enum OperKind { Conv, Pool };
-  bool valid = false;
   OperKind oper = Conv;
   StringAttr redOp;
   StringAttr poolExtOp;
@@ -3869,18 +3908,10 @@ struct Conv1DGenerator
   vector::CombiningKind reductionKind;
 
   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
-  // Returns true iff it is a valid conv/pooling op.
-  // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
-  // + yield) and rhs is not used) then it is the body of a pooling
-  // If conv, check for single `mul` predecessor. The `mul` operands must be
-  // block arguments or extension of block arguments.
-  // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
-  // must be block arguments or extension of block arguments.
-  bool setOperKind(Operation *reduceOp) {
+  void setOperKind(Operation *reduceOp) {
     int numBlockArguments =
         llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
-    switch (numBlockArguments) {
-    case 1: {
+    if (numBlockArguments == 1) {
       // Will be convolution if feeder is a MulOp.
       // A strength reduced version of MulOp for i1 type is AndOp which is also
       // supported. Otherwise, it can be pooling. This strength reduction logic
@@ -3892,27 +3923,13 @@ struct Conv1DGenerator
       oper = Pool;
       isPoolExt = true;
       poolExtOp = feedOp->getName().getIdentifier();
-    } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
-                  (isa<arith::AndIOp>(feedOp) &&
-                   feedOp->getResultTypes()[0].isInteger(1))) &&
-                 llvm::all_of(feedOp->getOperands(), [](Value v) {
-                   if (isa<BlockArgument>(v))
-                     return true;
-                   if (Operation *op = v.getDefiningOp())
-                     return isCastOfBlockArgument(op);
-                   return false;
-                 }))) {
-      return false;
+    } else {
+      oper = Conv;
     }
-      return true;
-    }
-    case 2:
-      // Must be pooling
+    } else {
+      // Pooling.
       oper = Pool;
       isPoolExt = false;
-      return true;
-    default:
-      return false;
     }
   }
 };
@@ -3924,7 +3941,6 @@ static FailureOr<Operation *> vectorizeConvolution(
     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
   Conv1DGenerator conv1dGen(rewriter, op);
-  assert(conv1dGen.isValid() && "Conv1DGenerator failed");
   auto res = conv1dGen.generateNonChanneledConv();
   if (succeeded(res))
     return res;