@@ -1940,7 +1940,10 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
19401940}
19411941
19421942namespace {
1943- bool isCastOfBlockArgument (Operation *op) {
1943+ enum class ConvOperationKind { Conv, Pool };
1944+ } // namespace
1945+
1946+ static bool isCastOfBlockArgument (Operation *op) {
19441947 return isa<CastOpInterface>(op) && op->getNumOperands () == 1 &&
19451948 isa<BlockArgument>(op->getOperand (0 ));
19461949}
@@ -1952,8 +1955,8 @@ bool isCastOfBlockArgument(Operation *op) {
19521955// block arguments or extension of block arguments.
19531956// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
19541957// must be block arguments or extension of block arguments.
1955- enum OperKind { Conv, Pool };
1956- bool getOperKind (Operation *reduceOp, OperKind &oper ) {
1958+ static std::optional<ConvOperationKind>
1959+ getConvOperationKind (Operation *reduceOp) {
19571960 int numBlockArguments =
19581961 llvm::count_if (reduceOp->getOperands (), llvm::IsaPred<BlockArgument>);
19591962
@@ -1966,33 +1969,34 @@ bool getOperKind(Operation *reduceOp, OperKind &oper) {
19661969 auto feedValIt = llvm::find_if_not (reduceOp->getOperands (),
19671970 llvm::IsaPred<BlockArgument>);
19681971 Operation *feedOp = (*feedValIt).getDefiningOp ();
1969- // llvm::outs() << "feedOp: " << *feedOp << "\n";
19701972 if (isCastOfBlockArgument (feedOp)) {
1971- oper = Pool;
1972- } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
1973- (isa<arith::AndIOp>(feedOp) &&
1974- feedOp->getResultTypes ()[0 ].isInteger (1 ))) &&
1975- llvm::all_of (feedOp->getOperands (), [](Value v) {
1976- if (isa<BlockArgument>(v))
1977- return true ;
1978- if (Operation *op = v.getDefiningOp ())
1979- return isCastOfBlockArgument (op);
1980- return false ;
1981- }))) {
1982- return false ;
1973+ return ConvOperationKind::Pool;
19831974 }
1984- return true ;
1975+
1976+ if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
1977+ (isa<arith::AndIOp>(feedOp) &&
1978+ feedOp->getResultTypes ()[0 ].isInteger (1 ))) &&
1979+ llvm::all_of (feedOp->getOperands (), [](Value v) {
1980+ if (isa<BlockArgument>(v))
1981+ return true ;
1982+ if (Operation *op = v.getDefiningOp ())
1983+ return isCastOfBlockArgument (op);
1984+ return false ;
1985+ }))) {
1986+ return std::nullopt ;
1987+ }
1988+
1989+ return ConvOperationKind::Conv;
19851990 }
19861991 case 2 :
19871992 // Must be pooling
1988- oper = Pool;
1989- return true ;
1993+ return ConvOperationKind::Pool;
19901994 default :
1991- return false ;
1995+ return std:: nullopt ;
19921996 }
19931997}
19941998
1995- bool isSupportedPoolKind (vector::CombiningKind kind) {
1999+ static bool isSupportedPoolKind (vector::CombiningKind kind) {
19962000 switch (kind) {
19972001 case vector::CombiningKind::ADD:
19982002 case vector::CombiningKind::MAXNUMF:
@@ -2008,7 +2012,6 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
20082012 return false ;
20092013 }
20102014}
2011- } // namespace
20122015
20132016static LogicalResult vectorizeConvOpPrecondition (linalg::LinalgOp convOp) {
20142017 if (convOp.getNumDpsInputs () != 2 || convOp.getNumDpsInits () != 1 )
@@ -2032,29 +2035,28 @@ static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
20322035 if (!reduceOp)
20332036 return failure ();
20342037
2035- OperKind oper = Conv ;
2036- if (!getOperKind (reduceOp, oper ))
2038+ auto maybeOper = getConvOperationKind (reduceOp) ;
2039+ if (!maybeOper. has_value ( ))
20372040 return failure ();
2041+
20382042 auto maybeKind = getCombinerOpKind (reduceOp);
20392043 // Typically convolution will have a `Add` CombiningKind but for i1 type it
20402044 // can get strength reduced to `OR` which is also supported. This strength
20412045 // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
20422046 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
20432047 *maybeKind != vector::CombiningKind::OR) &&
2044- (oper != Pool || !isSupportedPoolKind (*maybeKind)))) {
2048+ (*maybeOper != ConvOperationKind::Pool ||
2049+ !isSupportedPoolKind (*maybeKind)))) {
20452050 return failure ();
20462051 }
20472052
20482053 auto rhsRank = rhsShapedType.getRank ();
2049- switch (oper) {
2050- case Conv:
2051- if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3 )
2052- return failure ();
2053- break ;
2054- case Pool:
2054+ if (*maybeOper == ConvOperationKind::Pool) {
20552055 if (rhsRank != 1 )
20562056 return failure ();
2057- break ;
2057+ } else {
2058+ if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3 )
2059+ return failure ();
20582060 }
20592061
20602062 return success ();
@@ -3238,7 +3240,7 @@ struct Conv1DGenerator
32383240 Operation *reduceOp = matchLinalgReduction (linalgOp.getDpsInitOperand (0 ));
32393241 redOp = reduceOp->getName ().getIdentifier ();
32403242
3241- setOperKind (reduceOp);
3243+ setConvOperationKind (reduceOp);
32423244
32433245 auto maybeKind = getCombinerOpKind (reduceOp);
32443246 reductionKind = maybeKind.value ();
@@ -3293,11 +3295,11 @@ struct Conv1DGenerator
32933295 // out{n, w, f}
32943296 bindShapeDims (resShapedType, nSize, wSize, fSize );
32953297 switch (oper) {
3296- case Conv:
3298+ case ConvOperationKind:: Conv:
32973299 // kernel{kw, c, f}
32983300 bindShapeDims (rhsShapedType, kwSize, cSize);
32993301 break ;
3300- case Pool:
3302+ case ConvOperationKind:: Pool:
33013303 // kernel{kw}
33023304 bindShapeDims (rhsShapedType, kwSize);
33033305 cSize = fSize ;
@@ -3311,10 +3313,10 @@ struct Conv1DGenerator
33113313 1 ,
33123314 cSize};
33133315 switch (oper) {
3314- case Conv:
3316+ case ConvOperationKind:: Conv:
33153317 rhsShape = {kwSize, cSize, fSize };
33163318 break ;
3317- case Pool:
3319+ case ConvOperationKind:: Pool:
33183320 rhsShape = {kwSize};
33193321 break ;
33203322 }
@@ -3324,11 +3326,11 @@ struct Conv1DGenerator
33243326 // out{n, f, w}
33253327 bindShapeDims (resShapedType, nSize, fSize , wSize);
33263328 switch (oper) {
3327- case Conv:
3329+ case ConvOperationKind:: Conv:
33283330 // kernel{f, c, kw}
33293331 bindShapeDims (rhsShapedType, fSize , cSize, kwSize);
33303332 break ;
3331- case Pool:
3333+ case ConvOperationKind:: Pool:
33323334 // kernel{kw}
33333335 bindShapeDims (rhsShapedType, kwSize);
33343336 cSize = fSize ;
@@ -3341,10 +3343,10 @@ struct Conv1DGenerator
33413343 ((wSize - 1 ) * strideW + 1 ) + ((kwSize - 1 ) * dilationW + 1 ) -
33423344 1 };
33433345 switch (oper) {
3344- case Conv:
3346+ case ConvOperationKind:: Conv:
33453347 rhsShape = {fSize , cSize, kwSize};
33463348 break ;
3347- case Pool:
3349+ case ConvOperationKind:: Pool:
33483350 rhsShape = {kwSize};
33493351 break ;
33503352 }
@@ -3376,7 +3378,7 @@ struct Conv1DGenerator
33763378 lhsPadding);
33773379 // This is needed only for Conv.
33783380 Value rhs = nullptr ;
3379- if (oper == Conv)
3381+ if (oper == ConvOperationKind:: Conv)
33803382 rhs = rewriter.create <vector::TransferReadOp>(loc, rhsType, rhsShaped,
33813383 rhsPadding);
33823384 Value res = rewriter.create <vector::TransferReadOp>(loc, resType, resShaped,
@@ -3399,7 +3401,7 @@ struct Conv1DGenerator
33993401 static constexpr std::array<int64_t , 3 > permRhs = {2 , 1 , 0 };
34003402
34013403 // This is needed only for Conv.
3402- if (oper == Conv)
3404+ if (oper == ConvOperationKind:: Conv)
34033405 rhs = rewriter.create <vector::TransposeOp>(loc, rhs, permRhs);
34043406 // nfw -> nwf
34053407 static constexpr std::array<int64_t , 3 > permRes = {0 , 2 , 1 };
@@ -3417,7 +3419,7 @@ struct Conv1DGenerator
34173419 kwSize, strideW, dilationW, wSizeStep,
34183420 isSingleChanneled);
34193421 // Do not do for pooling.
3420- if (oper == Conv)
3422+ if (oper == ConvOperationKind:: Conv)
34213423 rhsVals = extractConvFilterSlices (rewriter, loc, rhs, kwSize);
34223424 resVals = extractConvResultSlices (rewriter, loc, res, nSize, wSize, fSize ,
34233425 wSizeStep, isSingleChanneled);
@@ -3432,7 +3434,7 @@ struct Conv1DGenerator
34323434 for (int64_t kw = 0 ; kw < kwSize; ++kw) {
34333435 for (int64_t w = 0 ; w < wSize; w += wSizeStep) {
34343436 switch (oper) {
3435- case Conv:
3437+ case ConvOperationKind:: Conv:
34363438 if (isSingleChanneled) {
34373439 resVals[w] = conv1dSliceAsOuterProduct (rewriter, loc,
34383440 lhsVals[linearIndex (kw, w)],
@@ -3443,7 +3445,7 @@ struct Conv1DGenerator
34433445 rhsVals[kw], resVals[w]);
34443446 }
34453447 break ;
3446- case Pool:
3448+ case ConvOperationKind:: Pool:
34473449 resVals[w] = pool1dSlice (rewriter, loc, lhsVals[linearIndex (kw, w)],
34483450 resVals[w]);
34493451 break ;
@@ -3898,7 +3900,7 @@ struct Conv1DGenerator
38983900 }
38993901
39003902private:
3901- OperKind oper = Conv;
3903+ ConvOperationKind oper = ConvOperationKind:: Conv;
39023904 StringAttr redOp;
39033905 StringAttr poolExtOp;
39043906 bool isPoolExt = false ;
@@ -3908,7 +3910,7 @@ struct Conv1DGenerator
39083910 vector::CombiningKind reductionKind;
39093911
39103912 // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3911- void setOperKind (Operation *reduceOp) {
3913+ void setConvOperationKind (Operation *reduceOp) {
39123914 int numBlockArguments =
39133915 llvm::count_if (reduceOp->getOperands (), llvm::IsaPred<BlockArgument>);
39143916 if (numBlockArguments == 1 ) {
@@ -3920,17 +3922,17 @@ struct Conv1DGenerator
39203922 llvm::IsaPred<BlockArgument>);
39213923 Operation *feedOp = (*feedValIt).getDefiningOp ();
39223924 if (isCastOfBlockArgument (feedOp)) {
3923- oper = Pool;
3925+ oper = ConvOperationKind:: Pool;
39243926 isPoolExt = true ;
39253927 poolExtOp = feedOp->getName ().getIdentifier ();
3926- } else {
3927- oper = Conv;
3928+ return ;
39283929 }
3929- } else {
3930- // Pooling.
3931- oper = Pool;
3932- isPoolExt = false ;
3930+ oper = ConvOperationKind::Conv;
3931+ return ;
39333932 }
3933+ // numBlockArugments == 2 and this is a pooling op.
3934+ oper = ConvOperationKind::Pool;
3935+ isPoolExt = false ;
39343936 }
39353937};
39363938} // namespace
0 commit comments