@@ -3060,8 +3060,11 @@ LogicalResult WinogradInputTransformOp::verify() {
30603060 int m = getM ();
30613061 int r = getR ();
30623062 int64_t tileSize = m + r - 1 ;
3063- bool leftTransform = inputH != 1 ;
3064- bool rightTransform = inputW != 1 ;
3063+
3064+ auto outputType = cast<ShapedType>(getOutput ().getType ());
3065+ ArrayRef<int64_t > outputShape = outputType.getShape ();
3066+ bool leftTransform = outputShape[getOutputAlphaHDim ()] != 1 ;
3067+ bool rightTransform = outputShape[getOutputAlphaWDim ()] != 1 ;
30653068
30663069 SmallVector<int64_t > expectedOutputShape (6 , inputH);
30673070 if (ShapedType::isDynamic (inputH)) {
@@ -3070,21 +3073,19 @@ LogicalResult WinogradInputTransformOp::verify() {
30703073 } else {
30713074 expectedOutputShape[getOutputAlphaHDim ()] = leftTransform ? tileSize : 1 ;
30723075 expectedOutputShape[getOutputTileHDim ()] =
3073- leftTransform ? (inputH - (r - 1 )) / m : 1 ;
3076+ leftTransform ? (inputH - (r - 1 )) / m : inputH ;
30743077 }
30753078 if (ShapedType::isDynamic (inputW)) {
30763079 expectedOutputShape[getOutputAlphaWDim ()] = tileSize;
30773080 expectedOutputShape[getOutputTileWDim ()] = ShapedType::kDynamic ;
30783081 } else {
30793082 expectedOutputShape[getOutputAlphaWDim ()] = rightTransform ? tileSize : 1 ;
30803083 expectedOutputShape[getOutputTileWDim ()] =
3081- rightTransform ? (inputW - (r - 1 )) / m : 1 ;
3084+ rightTransform ? (inputW - (r - 1 )) / m : inputW ;
30823085 }
30833086 expectedOutputShape[getOutputNDim ()] = inputShape[getInputNDim ()];
30843087 expectedOutputShape[getOutputCDim ()] = inputShape[getInputCDim ()];
30853088
3086- auto outputType = cast<ShapedType>(getOutput ().getType ());
3087- ArrayRef<int64_t > outputShape = outputType.getShape ();
30883089 if (failed (verifyCompatibleShape (expectedOutputShape, outputShape))) {
30893090 return emitOpError (" the output shape is not expected" );
30903091 }
@@ -3121,15 +3122,17 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
31213122 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
31223123 SmallVector<OpFoldResult> &resultSizes) {
31233124 IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
3124- ShapedType inputType = getInputOperandType ();
3125- ArrayRef<int64_t > inputShape = inputType.getShape ();
3126- int64_t inputH = inputShape[getInputHDim ()];
3127- int64_t inputW = inputShape[getInputWDim ()];
3125+ ShapedType outputType = getOutputOperandType ();
3126+ ArrayRef<int64_t > outputShape = outputType.getShape ();
3127+ int64_t outputAlphaH = outputShape[getOutputAlphaHDim ()];
3128+ int64_t outputAlphaW = outputShape[getOutputAlphaWDim ()];
3129+
31283130 int64_t m = getM ();
31293131 int64_t r = getR ();
31303132 int64_t alpha = m + r - 1 ;
3131- int64_t alphaH = inputH != 1 ? alpha : 1 ;
3132- int64_t alphaW = inputW != 1 ? alpha : 1 ;
3133+ int64_t alphaH = outputAlphaH != 1 ? alpha : 1 ;
3134+ int64_t alphaW = outputAlphaW != 1 ? alpha : 1 ;
3135+
31333136 IntegerAttr alphaHAttr = builder.getI64IntegerAttr (alphaH);
31343137 IntegerAttr alphaWAttr = builder.getI64IntegerAttr (alphaW);
31353138
@@ -3154,22 +3157,26 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
31543157 ArrayRef<OpFoldResult> offsets,
31553158 ArrayRef<OpFoldResult> sizes) {
31563159 IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
3157- IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
3158- ShapedType inputType = getInputOperandType ();
3159- ArrayRef<int64_t > inputShape = inputType.getShape ();
3160- int64_t inputH = inputShape[getInputHDim ()];
3161- int64_t inputW = inputShape[getInputWDim ()];
31623160 int64_t m = getM ();
31633161 int64_t r = getR ();
31643162
3163+ ShapedType outputType = getOutputOperandType ();
3164+ ArrayRef<int64_t > outputShape = outputType.getShape ();
3165+ int64_t alphaH = outputShape[getOutputAlphaHDim ()];
3166+ int64_t alphaW = outputShape[getOutputAlphaWDim ()];
3167+
31653168 Location loc = getLoc ();
31663169 MLIRContext *context = builder.getContext ();
3170+ auto identityAffineMap =
3171+ AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 )}, context);
31673172 auto offsetAffineMap =
31683173 AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
31693174 Value mappedOffsetH = affine::makeComposedAffineApply (
3170- builder, loc, offsetAffineMap, offsets[getOutputTileHDim ()]);
3175+ builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3176+ offsets[getOutputTileHDim ()]);
31713177 Value mappedOffsetW = affine::makeComposedAffineApply (
3172- builder, loc, offsetAffineMap, offsets[getOutputTileWDim ()]);
3178+ builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3179+ offsets[getOutputTileWDim ()]);
31733180 auto sizeAffineMap = AffineMap::get (
31743181 1 , 0 , {builder.getAffineDimExpr (0 ) * m + (r - 1 )}, context);
31753182 Value mappedSizeH = affine::makeComposedAffineApply (
@@ -3180,16 +3187,14 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
31803187 SmallVector<Value> tiledOperands;
31813188 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
31823189
3183- OpFoldResult offsetH =
3184- inputH != 1 ? OpFoldResult (mappedOffsetH) : OpFoldResult (zeroAttr);
3185- OpFoldResult offsetW =
3186- inputW != 1 ? OpFoldResult (mappedOffsetW) : OpFoldResult (zeroAttr);
3190+ OpFoldResult offsetH = OpFoldResult (mappedOffsetH);
3191+ OpFoldResult offsetW = OpFoldResult (mappedOffsetW);
31873192 sliceOffsets.append (
31883193 {offsets[getOutputNDim ()], offsetH, offsetW, offsets[getOutputCDim ()]});
31893194 OpFoldResult sizeH =
3190- inputH != 1 ? OpFoldResult (mappedSizeH) : OpFoldResult (oneAttr);
3195+ alphaH != 1 ? OpFoldResult (mappedSizeH) : OpFoldResult (oneAttr);
31913196 OpFoldResult sizeW =
3192- inputW != 1 ? OpFoldResult (mappedSizeW) : OpFoldResult (oneAttr);
3197+ alphaW != 1 ? OpFoldResult (mappedSizeW) : OpFoldResult (oneAttr);
31933198 sliceSizes.append (
31943199 {sizes[getOutputNDim ()], sizeH, sizeW, sizes[getOutputCDim ()]});
31953200 int64_t inputRank = getInputOperandRank ();
@@ -3297,28 +3302,29 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
32973302
32983303 Location loc = getLoc ();
32993304 MLIRContext *context = builder.getContext ();
3305+ auto identityAffineMap =
3306+ AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 )}, context);
33003307 auto affineMap =
33013308 AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
33023309
3310+ ShapedType valueType = getValueOperandType ();
3311+ ArrayRef<int64_t > valueShape = valueType.getShape ();
3312+ int64_t valueH = valueShape[0 ];
3313+ int64_t valueW = valueShape[1 ];
33033314 Value mappedOffsetH = affine::makeComposedAffineApply (
3304- builder, loc, affineMap, offsets[getValueTileHDim ()]);
3315+ builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3316+ offsets[getValueTileHDim ()]);
33053317 Value mappedOffsetW = affine::makeComposedAffineApply (
3306- builder, loc, affineMap, offsets[getValueTileWDim ()]);
3318+ builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3319+ offsets[getValueTileWDim ()]);
33073320 Value mappedSizeH = affine::makeComposedAffineApply (
33083321 builder, loc, affineMap, sizes[getValueTileHDim ()]);
33093322 Value mappedSizeW = affine::makeComposedAffineApply (
33103323 builder, loc, affineMap, sizes[getValueTileWDim ()]);
33113324
3312- ShapedType valueType = getValueOperandType ();
3313- ArrayRef<int64_t > valueShape = valueType.getShape ();
3314- int64_t valueH = valueShape[0 ];
3315- int64_t valueW = valueShape[1 ];
33163325 IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
3317- IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
3318- OpFoldResult offsetH =
3319- valueH != 1 ? OpFoldResult (mappedOffsetH) : OpFoldResult (zeroAttr);
3320- OpFoldResult offsetW =
3321- valueW != 1 ? OpFoldResult (mappedOffsetW) : OpFoldResult (zeroAttr);
3326+ OpFoldResult offsetH = OpFoldResult (mappedOffsetH);
3327+ OpFoldResult offsetW = OpFoldResult (mappedOffsetW);
33223328 OpFoldResult sizeH =
33233329 valueH != 1 ? OpFoldResult (mappedSizeH) : OpFoldResult (oneAttr);
33243330 OpFoldResult sizeW =
0 commit comments