@@ -2989,8 +2989,9 @@ LogicalResult WinogradFilterTransformOp::verify() {
29892989 ArrayRef<int64_t > filterShape = filterType.getShape ();
29902990 int64_t filterH = filterShape[getFilterHDim ()];
29912991 int64_t filterW = filterShape[getFilterWDim ()];
2992- int64_t r = getR ();
2993- int64_t m = getM ();
2992+ WinogradConv2DFmr fmr = getFmr ();
2993+ int64_t m, r;
2994+ std::tie (m, r) = getFmrFromWinogradConv2DFmr (fmr);
29942995
29952996 if (filterH != r && filterH != 1 )
29962997 return emitOpError (" expect filter height either equals to r or 1" );
@@ -3046,8 +3047,9 @@ LogicalResult WinogradFilterTransformOp::getResultTilePosition(
30463047 ArrayRef<int64_t > filterShape = filterType.getShape ();
30473048 int64_t filterH = filterShape[getFilterHDim ()];
30483049 int64_t filterW = filterShape[getFilterWDim ()];
3049- int64_t m = getM ();
3050- int64_t r = getR ();
3050+ WinogradConv2DFmr fmr = getFmr ();
3051+ int64_t m, r;
3052+ std::tie (m, r) = getFmrFromWinogradConv2DFmr (fmr);
30513053 int64_t alpha = m + r - 1 ;
30523054 int64_t alphaH = filterH != 1 ? alpha : 1 ;
30533055 int64_t alphaW = filterW != 1 ? alpha : 1 ;
@@ -3124,8 +3126,9 @@ LogicalResult WinogradInputTransformOp::verify() {
31243126 ArrayRef<int64_t > inputShape = inputType.getShape ();
31253127 int64_t inputH = inputShape[getInputHDim ()];
31263128 int64_t inputW = inputShape[getInputWDim ()];
3127- int m = getM ();
3128- int r = getR ();
3129+ WinogradConv2DFmr fmr = getFmr ();
3130+ int64_t m, r;
3131+ std::tie (m, r) = getFmrFromWinogradConv2DFmr (fmr);
31293132 int64_t tileSize = m + r - 1 ;
31303133
31313134 auto outputType = cast<ShapedType>(getOutput ().getType ());
@@ -3194,8 +3197,9 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
31943197 int64_t outputAlphaH = outputShape[getOutputAlphaHDim ()];
31953198 int64_t outputAlphaW = outputShape[getOutputAlphaWDim ()];
31963199
3197- int64_t m = getM ();
3198- int64_t r = getR ();
3200+ WinogradConv2DFmr fmr = getFmr ();
3201+ int64_t m, r;
3202+ std::tie (m, r) = getFmrFromWinogradConv2DFmr (fmr);
31993203 int64_t alpha = m + r - 1 ;
32003204 int64_t alphaH = outputAlphaH != 1 ? alpha : 1 ;
32013205 int64_t alphaW = outputAlphaW != 1 ? alpha : 1 ;
@@ -3224,8 +3228,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
32243228 ArrayRef<OpFoldResult> offsets,
32253229 ArrayRef<OpFoldResult> sizes) {
32263230 IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
3227- int64_t m = getM ();
3228- int64_t r = getR ();
3231+ WinogradConv2DFmr fmr = getFmr ();
3232+ int64_t m, r;
3233+ std::tie (m, r) = getFmrFromWinogradConv2DFmr (fmr);
32293234
32303235 ShapedType outputType = getOutputOperandType ();
32313236 ArrayRef<int64_t > outputShape = outputType.getShape ();
@@ -3303,8 +3308,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
33033308 int64_t valueW = valueShape[getValueAlphaWDim ()];
33043309 int64_t valueTileH = valueShape[getValueTileHDim ()];
33053310 int64_t valueTileW = valueShape[getValueTileWDim ()];
3306- int m = getM ();
3307- int r = getR ();
3311+ WinogradConv2DFmr fmr = getFmr ();
3312+ int64_t m, r;
3313+ std::tie (m, r) = getFmrFromWinogradConv2DFmr (fmr);
33083314 bool leftTransform = valueH != 1 ;
33093315 bool rightTransform = valueW != 1 ;
33103316
@@ -3365,7 +3371,9 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
33653371 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
33663372 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
33673373 SmallVector<OpFoldResult> &resultSizes) {
3368- int64_t m = getM ();
3374+ WinogradConv2DFmr fmr = getFmr ();
3375+ int64_t m, r;
3376+ std::tie (m, r) = getFmrFromWinogradConv2DFmr (fmr);
33693377
33703378 Location loc = getLoc ();
33713379 MLIRContext *context = builder.getContext ();
@@ -3623,6 +3631,27 @@ verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
36233631namespace mlir {
36243632namespace linalg {
36253633
3634+ std::optional<WinogradConv2DFmr> getWinogradConv2DFmr (int64_t m, int64_t r) {
3635+ if (m == 2 && r == 3 )
3636+ return WinogradConv2DFmr::F_2_3;
3637+ if (m == 4 && r == 3 )
3638+ return WinogradConv2DFmr::F_4_3;
3639+ if (m == 2 && r == 5 )
3640+ return WinogradConv2DFmr::F_2_5;
3641+ return std::nullopt ;
3642+ }
3643+
3644+ std::pair<int64_t , int64_t > getFmrFromWinogradConv2DFmr (WinogradConv2DFmr fmr) {
3645+ switch (fmr) {
3646+ case WinogradConv2DFmr::F_2_3:
3647+ return {2 , 3 };
3648+ case WinogradConv2DFmr::F_4_3:
3649+ return {4 , 3 };
3650+ case WinogradConv2DFmr::F_2_5:
3651+ return {2 , 5 };
3652+ }
3653+ }
3654+
36263655// ===----------------------------------------------------------------------===//
36273656// MatMulOp
36283657// ===----------------------------------------------------------------------===//
0 commit comments