@@ -3063,8 +3063,11 @@ LogicalResult WinogradInputTransformOp::verify() {
30633063 int m = getM ();
30643064 int r = getR ();
30653065 int64_t tileSize = m + r - 1 ;
3066- bool leftTransform = inputH != 1 ;
3067- bool rightTransform = inputW != 1 ;
3066+
3067+ auto outputType = cast<ShapedType>(getOutput ().getType ());
3068+ ArrayRef<int64_t > outputShape = outputType.getShape ();
3069+ bool leftTransform = outputShape[getOutputAlphaHDim ()] != 1 ;
3070+ bool rightTransform = outputShape[getOutputAlphaWDim ()] != 1 ;
30683071
30693072 SmallVector<int64_t > expectedOutputShape (6 , inputH);
30703073 if (ShapedType::isDynamic (inputH)) {
@@ -3073,21 +3076,19 @@ LogicalResult WinogradInputTransformOp::verify() {
30733076 } else {
30743077 expectedOutputShape[getOutputAlphaHDim ()] = leftTransform ? tileSize : 1 ;
30753078 expectedOutputShape[getOutputTileHDim ()] =
3076- leftTransform ? (inputH - (r - 1 )) / m : 1 ;
3079+ leftTransform ? (inputH - (r - 1 )) / m : inputH ;
30773080 }
30783081 if (ShapedType::isDynamic (inputW)) {
30793082 expectedOutputShape[getOutputAlphaWDim ()] = tileSize;
30803083 expectedOutputShape[getOutputTileWDim ()] = ShapedType::kDynamic ;
30813084 } else {
30823085 expectedOutputShape[getOutputAlphaWDim ()] = rightTransform ? tileSize : 1 ;
30833086 expectedOutputShape[getOutputTileWDim ()] =
3084- rightTransform ? (inputW - (r - 1 )) / m : 1 ;
3087+ rightTransform ? (inputW - (r - 1 )) / m : inputW ;
30853088 }
30863089 expectedOutputShape[getOutputNDim ()] = inputShape[getInputNDim ()];
30873090 expectedOutputShape[getOutputCDim ()] = inputShape[getInputCDim ()];
30883091
3089- auto outputType = cast<ShapedType>(getOutput ().getType ());
3090- ArrayRef<int64_t > outputShape = outputType.getShape ();
30913092 if (failed (verifyCompatibleShape (expectedOutputShape, outputShape))) {
30923093 return emitOpError (" the output shape is not expected" );
30933094 }
@@ -3124,15 +3125,17 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
31243125 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
31253126 SmallVector<OpFoldResult> &resultSizes) {
31263127 IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
3127- ShapedType inputType = getInputOperandType ();
3128- ArrayRef<int64_t > inputShape = inputType.getShape ();
3129- int64_t inputH = inputShape[getInputHDim ()];
3130- int64_t inputW = inputShape[getInputWDim ()];
3128+ ShapedType outputType = getOutputOperandType ();
3129+ ArrayRef<int64_t > outputShape = outputType.getShape ();
3130+ int64_t outputAlphaH = outputShape[getOutputAlphaHDim ()];
3131+ int64_t outputAlphaW = outputShape[getOutputAlphaWDim ()];
3132+
31313133 int64_t m = getM ();
31323134 int64_t r = getR ();
31333135 int64_t alpha = m + r - 1 ;
3134- int64_t alphaH = inputH != 1 ? alpha : 1 ;
3135- int64_t alphaW = inputW != 1 ? alpha : 1 ;
3136+ int64_t alphaH = outputAlphaH != 1 ? alpha : 1 ;
3137+ int64_t alphaW = outputAlphaW != 1 ? alpha : 1 ;
3138+
31363139 IntegerAttr alphaHAttr = builder.getI64IntegerAttr (alphaH);
31373140 IntegerAttr alphaWAttr = builder.getI64IntegerAttr (alphaW);
31383141
@@ -3165,6 +3168,11 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
31653168 int64_t m = getM ();
31663169 int64_t r = getR ();
31673170
3171+ ShapedType outputType = getOutputOperandType ();
3172+ ArrayRef<int64_t > outputShape = outputType.getShape ();
3173+ int64_t alphaH = outputShape[getOutputAlphaHDim ()];
3174+ int64_t alphaW = outputShape[getOutputAlphaWDim ()];
3175+
31683176 Location loc = getLoc ();
31693177 MLIRContext *context = builder.getContext ();
31703178 auto offsetAffineMap =
@@ -3190,9 +3198,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
31903198 sliceOffsets.append (
31913199 {offsets[getOutputNDim ()], offsetH, offsetW, offsets[getOutputCDim ()]});
31923200 OpFoldResult sizeH =
3193- inputH != 1 ? OpFoldResult (mappedSizeH) : OpFoldResult (oneAttr);
3201+ alphaH != 1 ? OpFoldResult (mappedSizeH) : OpFoldResult (oneAttr);
31943202 OpFoldResult sizeW =
3195- inputW != 1 ? OpFoldResult (mappedSizeW) : OpFoldResult (oneAttr);
3203+ alphaW != 1 ? OpFoldResult (mappedSizeW) : OpFoldResult (oneAttr);
31963204 sliceSizes.append (
31973205 {sizes[getOutputNDim ()], sizeH, sizeW, sizes[getOutputCDim ()]});
31983206 int64_t inputRank = getInputOperandRank ();
0 commit comments