@@ -1971,6 +1971,8 @@ getConvOperationKind(Operation *reduceOp) {
   // is in `buildBinaryFn` helper in the Linalg dialect.
   auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                      llvm::IsaPred<BlockArgument>);
+  assert(feedValIt != reduceOp->operand_end() &&
+         "Expected a non-block argument operand");
   Operation *feedOp = (*feedValIt).getDefiningOp();
   if (isCastOfBlockArgument(feedOp)) {
     return ConvOperationKind::Pool;
@@ -2017,17 +2019,12 @@ static bool isSupportedPoolKind(vector::CombiningKind kind) {
 }
 
 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
-  if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
-    return failure();
-
-  auto lhsShaped = convOp.getDpsInputOperand(0)->get();
-  auto rhsShaped = convOp.getDpsInputOperand(1)->get();
-  auto resShaped = convOp.getDpsInitOperand(0)->get();
-  auto lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
-  auto rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
-  auto resShapedType = dyn_cast<ShapedType>(resShaped.getType());
-  if (!lhsShapedType || !rhsShapedType || !resShapedType)
-    return failure();
+  auto getOperandType = [&](auto operand) {
+    return dyn_cast<ShapedType>((operand->get()).getType());
+  };
+  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
+  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
+  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
   // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
   // (non-channeled convolution -> LHS and RHS both have single dimensions).
   if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&