diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index ae04c2b6b2a5b..2dcd897330d1e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1939,6 +1939,130 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, return success(); } +namespace { +enum class ConvOperationKind { Conv, Pool }; +} // namespace + +static bool isCastOfBlockArgument(Operation *op) { + return isa(op) && op->getNumOperands() == 1 && + isa(op->getOperand(0)); +} + +// Returns the ConvOperationKind of the op using reduceOp of the generic +// payload. If it is neither a convolution nor a pooling, it returns +// std::nullopt. +// +// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction +// + yield) and rhs is not used) then it is the body of a pooling +// If conv, check for single `mul` predecessor. The `mul` operands must be +// block arguments or extension of block arguments. +// Otherwise, check for one or zero `ext` predecessor. The `ext` operands +// must be block arguments or extension of block arguments. +static std::optional +getConvOperationKind(Operation *reduceOp) { + int numBlockArguments = + llvm::count_if(reduceOp->getOperands(), llvm::IsaPred); + + switch (numBlockArguments) { + case 1: { + // Will be convolution if feeder is a MulOp. + // A strength reduced version of MulOp for i1 type is AndOp which is also + // supported. Otherwise, it can be pooling. This strength reduction logic + // is in `buildBinaryFn` helper in the Linalg dialect. 
+ auto feedValIt = llvm::find_if_not(reduceOp->getOperands(), + llvm::IsaPred); + assert(feedValIt != reduceOp->operand_end() && + "Expected a non-block argument operand"); + Operation *feedOp = (*feedValIt).getDefiningOp(); + if (isCastOfBlockArgument(feedOp)) { + return ConvOperationKind::Pool; + } + + if (!((isa(feedOp) || + (isa(feedOp) && + feedOp->getResultTypes()[0].isInteger(1))) && + llvm::all_of(feedOp->getOperands(), [](Value v) { + if (isa(v)) + return true; + if (Operation *op = v.getDefiningOp()) + return isCastOfBlockArgument(op); + return false; + }))) { + return std::nullopt; + } + + return ConvOperationKind::Conv; + } + case 2: + // Must be pooling + return ConvOperationKind::Pool; + default: + return std::nullopt; + } +} + +static bool isSupportedPoolKind(vector::CombiningKind kind) { + switch (kind) { + case vector::CombiningKind::ADD: + case vector::CombiningKind::MAXNUMF: + case vector::CombiningKind::MAXIMUMF: + case vector::CombiningKind::MAXSI: + case vector::CombiningKind::MAXUI: + case vector::CombiningKind::MINNUMF: + case vector::CombiningKind::MINIMUMF: + case vector::CombiningKind::MINSI: + case vector::CombiningKind::MINUI: + return true; + default: + return false; + } +} + +static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) { + auto getOperandType = [&](auto operand) { + return dyn_cast((operand->get()).getType()); + }; + ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0)); + ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1)); + ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0)); + // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR + // (non-channeled convolution -> LHS and RHS both have single dimensions). + // Note that this also ensures 2D and 3D convolutions are rejected. 
+ if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) && + (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1)) + return failure(); + + Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0)); + if (!reduceOp) + return failure(); + + auto maybeOper = getConvOperationKind(reduceOp); + if (!maybeOper.has_value()) + return failure(); + + auto maybeKind = getCombinerOpKind(reduceOp); + // Typically convolution will have a `Add` CombiningKind but for i1 type it + // can get strength reduced to `OR` which is also supported. This strength + // reduction logic is in `buildBinaryFn` helper in the Linalg dialect. + if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD && + *maybeKind != vector::CombiningKind::OR) && + (*maybeOper != ConvOperationKind::Pool || + !isSupportedPoolKind(*maybeKind)))) { + return failure(); + } + + auto rhsRank = rhsShapedType.getRank(); + if (*maybeOper == ConvOperationKind::Pool) { + if (rhsRank != 1) + return failure(); + } else { + if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3) + return failure(); + } + + return success(); +} + static LogicalResult vectorizeLinalgOpPrecondition( LinalgOp linalgOp, ArrayRef inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv) { @@ -1991,7 +2115,8 @@ static LogicalResult vectorizeLinalgOpPrecondition( // features. But we will still need stride/dilation attributes that will be // annoying to reverse-engineer... if (isa(linalgOp.getOperation())) - return success(); + return vectorizeConvOpPrecondition(linalgOp); + // TODO: the common vector shape is equal to the static loop sizes only when // all indexing maps are projected permutations. For convs and stencils the // logic will need to evolve. 
@@ -3067,28 +3192,6 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) { } namespace { -bool isCastOfBlockArgument(Operation *op) { - return isa(op) && op->getNumOperands() == 1 && - isa(op->getOperand(0)); -} - -bool isSupportedPoolKind(vector::CombiningKind kind) { - switch (kind) { - case vector::CombiningKind::ADD: - case vector::CombiningKind::MAXNUMF: - case vector::CombiningKind::MAXIMUMF: - case vector::CombiningKind::MAXSI: - case vector::CombiningKind::MAXUI: - case vector::CombiningKind::MINNUMF: - case vector::CombiningKind::MINIMUMF: - case vector::CombiningKind::MINSI: - case vector::CombiningKind::MINUI: - return true; - default: - return false; - } -} - /// Generate a vector implementation for either: /// ``` /// Op def: ( w, kw ) @@ -3125,58 +3228,32 @@ bool isSupportedPoolKind(vector::CombiningKind kind) { /// kw is unrolled, w is unrolled iff dilationW > 1. struct Conv1DGenerator : public StructuredGenerator { - Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW, - int dilationW) - : StructuredGenerator(rewriter, linalgOp), - strideW(strideW), dilationW(dilationW) { - // Determine whether `linalgOp` can be generated with this generator - if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1) - return; + Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) + : StructuredGenerator(rewriter, linalgOp) { + lhsShaped = linalgOp.getDpsInputOperand(0)->get(); rhsShaped = linalgOp.getDpsInputOperand(1)->get(); resShaped = linalgOp.getDpsInitOperand(0)->get(); lhsShapedType = dyn_cast(lhsShaped.getType()); rhsShapedType = dyn_cast(rhsShaped.getType()); resShapedType = dyn_cast(resShaped.getType()); - if (!lhsShapedType || !rhsShapedType || !resShapedType) - return; - // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR - // (non-channeled convolution -> LHS and RHS both have single dimensions). 
- if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) && - (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1)) - return; Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0)); - if (!reduceOp) - return; redOp = reduceOp->getName().getIdentifier(); - if (!setOperKind(reduceOp)) - return; + setConvOperationKind(reduceOp); + auto maybeKind = getCombinerOpKind(reduceOp); - // Typically convolution will have a `Add` CombiningKind but for i1 type it - // can get strength reduced to `OR` which is also supported. This strength - // reduction logic is in `buildBinaryFn` helper in the Linalg dialect. - if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD && - *maybeKind != vector::CombiningKind::OR) && - (oper != Pool || !isSupportedPoolKind(*maybeKind)))) { - return; - } reductionKind = maybeKind.value(); - auto rhsRank = rhsShapedType.getRank(); - switch (oper) { - case Conv: - if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3) - return; - break; - case Pool: - if (rhsRank != 1) - return; - break; - } - // The op is now known to be valid. - valid = true; + // The ConvolutionOpInterface gives us guarantees of existence for + // strides/dilations. However, we do not need to rely on those, we can + // simply use them if present, otherwise use the default and let the generic + // conv. matcher in the ConvGenerator succeed or fail. + auto strides = linalgOp->getAttrOfType("strides"); + auto dilations = linalgOp->getAttrOfType("dilations"); + strideW = strides ? *strides.getValues().begin() : 1; + dilationW = dilations ? *dilations.getValues().begin() : 1; } /// Generate a vector implementation for: @@ -3198,9 +3275,6 @@ struct Conv1DGenerator /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is /// > 1. 
FailureOr conv(Conv1DOpOrder conv1DOpOrder) { - if (!valid) - return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool"); - int64_t nSize, wSize, cSize, kwSize, fSize; SmallVector lhsShape, rhsShape, resShape; bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W); @@ -3222,11 +3296,11 @@ struct Conv1DGenerator // out{n, w, f} bindShapeDims(resShapedType, nSize, wSize, fSize); switch (oper) { - case Conv: + case ConvOperationKind::Conv: // kernel{kw, c, f} bindShapeDims(rhsShapedType, kwSize, cSize); break; - case Pool: + case ConvOperationKind::Pool: // kernel{kw} bindShapeDims(rhsShapedType, kwSize); cSize = fSize; @@ -3240,10 +3314,10 @@ struct Conv1DGenerator 1, cSize}; switch (oper) { - case Conv: + case ConvOperationKind::Conv: rhsShape = {kwSize, cSize, fSize}; break; - case Pool: + case ConvOperationKind::Pool: rhsShape = {kwSize}; break; } @@ -3253,11 +3327,11 @@ struct Conv1DGenerator // out{n, f, w} bindShapeDims(resShapedType, nSize, fSize, wSize); switch (oper) { - case Conv: + case ConvOperationKind::Conv: // kernel{f, c, kw} bindShapeDims(rhsShapedType, fSize, cSize, kwSize); break; - case Pool: + case ConvOperationKind::Pool: // kernel{kw} bindShapeDims(rhsShapedType, kwSize); cSize = fSize; @@ -3270,10 +3344,10 @@ struct Conv1DGenerator ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1}; switch (oper) { - case Conv: + case ConvOperationKind::Conv: rhsShape = {fSize, cSize, kwSize}; break; - case Pool: + case ConvOperationKind::Pool: rhsShape = {kwSize}; break; } @@ -3305,7 +3379,7 @@ struct Conv1DGenerator lhsPadding); // This is needed only for Conv. Value rhs = nullptr; - if (oper == Conv) + if (oper == ConvOperationKind::Conv) rhs = rewriter.create(loc, rhsType, rhsShaped, rhsPadding); Value res = rewriter.create(loc, resType, resShaped, @@ -3328,7 +3402,7 @@ struct Conv1DGenerator static constexpr std::array permRhs = {2, 1, 0}; // This is needed only for Conv. 
- if (oper == Conv) + if (oper == ConvOperationKind::Conv) rhs = rewriter.create(loc, rhs, permRhs); // nfw -> nwf static constexpr std::array permRes = {0, 2, 1}; @@ -3346,7 +3420,7 @@ struct Conv1DGenerator kwSize, strideW, dilationW, wSizeStep, isSingleChanneled); // Do not do for pooling. - if (oper == Conv) + if (oper == ConvOperationKind::Conv) rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize); resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize, wSizeStep, isSingleChanneled); @@ -3361,7 +3435,7 @@ struct Conv1DGenerator for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { switch (oper) { - case Conv: + case ConvOperationKind::Conv: if (isSingleChanneled) { resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc, lhsVals[linearIndex(kw, w)], @@ -3372,7 +3446,7 @@ struct Conv1DGenerator rhsVals[kw], resVals[w]); } break; - case Pool: + case ConvOperationKind::Pool: resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)], resVals[w]); break; @@ -3483,9 +3557,6 @@ struct Conv1DGenerator FailureOr depthwiseConv(uint64_t channelDimVecSize, bool channelDimScalableFlag, bool flatten) { - if (!valid) - return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv"); - bool scalableChDim = false; bool useMasking = false; int64_t nSize, wSize, cSize, kwSize; @@ -3830,9 +3901,7 @@ struct Conv1DGenerator } private: - enum OperKind { Conv, Pool }; - bool valid = false; - OperKind oper = Conv; + ConvOperationKind oper = ConvOperationKind::Conv; StringAttr redOp; StringAttr poolExtOp; bool isPoolExt = false; @@ -3842,18 +3911,10 @@ struct Conv1DGenerator vector::CombiningKind reductionKind; // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops. - // Returns true iff it is a valid conv/pooling op. 
- // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction - // + yield) and rhs is not used) then it is the body of a pooling - // If conv, check for single `mul` predecessor. The `mul` operands must be - // block arguments or extension of block arguments. - // Otherwise, check for one or zero `ext` predecessor. The `ext` operands - // must be block arguments or extension of block arguments. - bool setOperKind(Operation *reduceOp) { + void setConvOperationKind(Operation *reduceOp) { int numBlockArguments = llvm::count_if(reduceOp->getOperands(), llvm::IsaPred); - switch (numBlockArguments) { - case 1: { + if (numBlockArguments == 1) { // Will be convolution if feeder is a MulOp. // A strength reduced version of MulOp for i1 type is AndOp which is also // supported. Otherwise, it can be pooling. This strength reduction logic @@ -3862,31 +3923,17 @@ struct Conv1DGenerator llvm::IsaPred); Operation *feedOp = (*feedValIt).getDefiningOp(); if (isCastOfBlockArgument(feedOp)) { - oper = Pool; + oper = ConvOperationKind::Pool; isPoolExt = true; poolExtOp = feedOp->getName().getIdentifier(); - } else if (!((isa(feedOp) || - (isa(feedOp) && - feedOp->getResultTypes()[0].isInteger(1))) && - llvm::all_of(feedOp->getOperands(), [](Value v) { - if (isa(v)) - return true; - if (Operation *op = v.getDefiningOp()) - return isCastOfBlockArgument(op); - return false; - }))) { - return false; + return; } - return true; - } - case 2: - // Must be pooling - oper = Pool; - isPoolExt = false; - return true; - default: - return false; + oper = ConvOperationKind::Conv; + return; } + // numBlockArguments == 2 and this is a pooling op. 
+ oper = ConvOperationKind::Pool; + isPoolExt = false; } }; } // namespace @@ -3896,28 +3943,20 @@ struct Conv1DGenerator static FailureOr vectorizeConvolution( RewriterBase &rewriter, LinalgOp op, ArrayRef inputVecSizes, ArrayRef inputScalableVecDims, bool flatten1DDepthwiseConv) { - // The ConvolutionOpInterface gives us guarantees of existence for - // strides/dilations. However, we do not need to rely on those, we can - // simply use them if present, otherwise use the default and let the generic - // conv. matcher in the ConvGenerator succeed or fail. - auto strides = op->getAttrOfType("strides"); - auto dilations = op->getAttrOfType("dilations"); - auto stride = strides ? *strides.getValues().begin() : 1; - auto dilation = dilations ? *dilations.getValues().begin() : 1; - Conv1DGenerator e(rewriter, op, stride, dilation); - auto res = e.generateNonChanneledConv(); + Conv1DGenerator conv1dGen(rewriter, op); + auto res = conv1dGen.generateNonChanneledConv(); if (succeeded(res)) return res; - res = e.generateNwcConv(); + res = conv1dGen.generateNwcConv(); if (succeeded(res)) return res; - res = e.generateNcwConv(); + res = conv1dGen.generateNcwConv(); if (succeeded(res)) return res; - res = e.generateNwcPooling(); + res = conv1dGen.generateNwcPooling(); if (succeeded(res)) return res; - res = e.generateNcwPooling(); + res = conv1dGen.generateNcwPooling(); if (succeeded(res)) return res; @@ -3940,8 +3979,8 @@ static FailureOr vectorizeConvolution( vecChDimSize = inputVecSizes[chDimIdx]; vecChDimScalableFlag = inputScalableVecDims[chDimIdx]; } - return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag, - flatten1DDepthwiseConv); + return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag, + flatten1DDepthwiseConv); } struct VectorizeConvolution : public OpInterfaceRewritePattern { diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir index 8f3b199145ce0..2d1f0191eb798 
100644 --- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir +++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir @@ -112,6 +112,55 @@ module attributes {transform.with_named_sequence} { // ----- +func.func @conv2d_nchw_fchw(%input: tensor<1x5x8x8xf32>, %filter: tensor<4x5x3x3xf32>, %output: tensor<1x4x6x6xf32>) { + // expected-error @+1 {{Attempted to vectorize, but failed}} + linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<1x5x8x8xf32>, tensor<4x5x3x3xf32>) outs(%output : tensor<1x4x6x6xf32>) -> tensor<1x4x6x6xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %0 : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @conv2d_nhwc_fhwc(%input: tensor<1x8x8x5xf32>, %filter: tensor<4x3x3x5xf32>, %output: tensor<1x6x6x4xf32>) { + // expected-error @+1 {{Attempted to vectorize, but failed}} + linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<1x8x8x5xf32>, tensor<4x3x3x5xf32>) outs(%output : tensor<1x6x6x4xf32>) -> tensor<1x6x6x4xf32> + return +} + + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %0 : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @conv3d_ncdhw_fcdhw(%input: tensor<1x5x8x8x8xf32>, %filter: tensor<4x5x3x3x3xf32>, %output: tensor<1x4x6x6x6xf32>) { + // expected-error @+1 {{Attempted to vectorize, but failed}} + 
linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : vector<3xi64>, strides = dense<1> : vector<3xi64>} ins(%input, %filter : tensor<1x5x8x8x8xf32>, tensor<4x5x3x3x3xf32>) outs(%output : tensor<1x4x6x6x6xf32>) -> tensor<1x4x6x6x6xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_3d_ncdhw_fcdhw"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %0 : !transform.any_op + transform.yield + } +} + +// ----- + func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> { %pad = arith.constant 0.000000e+00 : f32 // expected-error @+1 {{Attempted to vectorize, but failed}}