Skip to content

Commit c262828

Browse files
committed
[MLIR][Linalg] Fixes for Winograd decomposition and for tiling
The PR addresses issues with filters of shape 1 x r and r x 1, and with tiling. Signed-off-by: Dmitriy Smirnov <[email protected]>
1 parent 7d01a8f commit c262828

File tree

5 files changed

+186
-65
lines changed

5 files changed

+186
-65
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
155155
}
156156

157157
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
158-
[AllElementTypesMatch<["filter", "output"]>,
158+
[AllElementTypesMatch<["filter", "output"]>, DestinationStyleOpInterface,
159159
DeclareOpInterfaceMethods<TilingInterface,
160160
["getIterationDomain",
161161
"getLoopIteratorTypes",
@@ -220,12 +220,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
220220
int64_t getFilterCDim() {
221221
return 3;
222222
}
223+
// Hook for the DestinationStyleOpInterface trait added to this op's trait
// list: reports the `output` operand (via getOutputMutable()) as the op's
// destination-style init operand range.
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
223224
}];
224225
let hasVerifier = 1;
225226
}
226227

227228
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
228-
[AllElementTypesMatch<["input", "output"]>,
229+
[AllElementTypesMatch<["input", "output"]>, DestinationStyleOpInterface,
229230
DeclareOpInterfaceMethods<TilingInterface,
230231
["getIterationDomain",
231232
"getLoopIteratorTypes",
@@ -308,6 +309,7 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
308309
int64_t getOutputCDim() {
309310
return 5;
310311
}
312+
// Hook for the DestinationStyleOpInterface trait added to this op's trait
// list: reports the `output` operand (via getOutputMutable()) as the op's
// destination-style init operand range.
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
311313
}];
312314
let hasVerifier = 1;
313315
}

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3063,8 +3063,11 @@ LogicalResult WinogradInputTransformOp::verify() {
30633063
int m = getM();
30643064
int r = getR();
30653065
int64_t tileSize = m + r - 1;
3066-
bool leftTransform = inputH != 1;
3067-
bool rightTransform = inputW != 1;
3066+
3067+
auto outputType = cast<ShapedType>(getOutput().getType());
3068+
ArrayRef<int64_t> outputShape = outputType.getShape();
3069+
bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3070+
bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
30683071

30693072
SmallVector<int64_t> expectedOutputShape(6, inputH);
30703073
if (ShapedType::isDynamic(inputH)) {
@@ -3073,21 +3076,19 @@ LogicalResult WinogradInputTransformOp::verify() {
30733076
} else {
30743077
expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
30753078
expectedOutputShape[getOutputTileHDim()] =
3076-
leftTransform ? (inputH - (r - 1)) / m : 1;
3079+
leftTransform ? (inputH - (r - 1)) / m : inputH;
30773080
}
30783081
if (ShapedType::isDynamic(inputW)) {
30793082
expectedOutputShape[getOutputAlphaWDim()] = tileSize;
30803083
expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
30813084
} else {
30823085
expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
30833086
expectedOutputShape[getOutputTileWDim()] =
3084-
rightTransform ? (inputW - (r - 1)) / m : 1;
3087+
rightTransform ? (inputW - (r - 1)) / m : inputW;
30853088
}
30863089
expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
30873090
expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
30883091

3089-
auto outputType = cast<ShapedType>(getOutput().getType());
3090-
ArrayRef<int64_t> outputShape = outputType.getShape();
30913092
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
30923093
return emitOpError("the output shape is not expected");
30933094
}
@@ -3124,15 +3125,17 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
31243125
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
31253126
SmallVector<OpFoldResult> &resultSizes) {
31263127
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3127-
ShapedType inputType = getInputOperandType();
3128-
ArrayRef<int64_t> inputShape = inputType.getShape();
3129-
int64_t inputH = inputShape[getInputHDim()];
3130-
int64_t inputW = inputShape[getInputWDim()];
3128+
ShapedType outputType = getOutputOperandType();
3129+
ArrayRef<int64_t> outputShape = outputType.getShape();
3130+
int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3131+
int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3132+
31313133
int64_t m = getM();
31323134
int64_t r = getR();
31333135
int64_t alpha = m + r - 1;
3134-
int64_t alphaH = inputH != 1 ? alpha : 1;
3135-
int64_t alphaW = inputW != 1 ? alpha : 1;
3136+
int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3137+
int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3138+
31363139
IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
31373140
IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
31383141

@@ -3165,6 +3168,11 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
31653168
int64_t m = getM();
31663169
int64_t r = getR();
31673170

3171+
ShapedType outputType = getOutputOperandType();
3172+
ArrayRef<int64_t> outputShape = outputType.getShape();
3173+
int64_t alphaH = outputShape[getOutputAlphaHDim()];
3174+
int64_t alphaW = outputShape[getOutputAlphaWDim()];
3175+
31683176
Location loc = getLoc();
31693177
MLIRContext *context = builder.getContext();
31703178
auto offsetAffineMap =
@@ -3190,9 +3198,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
31903198
sliceOffsets.append(
31913199
{offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
31923200
OpFoldResult sizeH =
3193-
inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3201+
alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
31943202
OpFoldResult sizeW =
3195-
inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3203+
alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
31963204
sliceSizes.append(
31973205
{sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
31983206
int64_t inputRank = getInputOperandRank();

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -514,12 +514,14 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
514514
Value CIter = ivs[3];
515515

516516
auto context = builder.getContext();
517+
518+
auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
517519
auto affineMap =
518520
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
519-
Value heightOffset =
520-
builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
521-
Value widthOffset =
522-
builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
521+
Value heightOffset = builder.create<affine::AffineApplyOp>(
522+
loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
523+
Value widthOffset = builder.create<affine::AffineApplyOp>(
524+
loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
523525

524526
// Extract (H, W) from (N, H, W, C).
525527
auto extractInput =
@@ -753,12 +755,13 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
753755
Value zero = builder.create<arith::ConstantOp>(
754756
loc, rewriter.getZeroAttr(elementType));
755757

758+
auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
756759
auto affineMap =
757760
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
758-
Value heightOffset =
759-
builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
760-
Value widthOffset =
761-
builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
761+
Value heightOffset = builder.create<affine::AffineApplyOp>(
762+
loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
763+
Value widthOffset = builder.create<affine::AffineApplyOp>(
764+
loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
762765

763766
Value outInitVal =
764767
extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
@@ -1075,16 +1078,17 @@ FailureOr<Operation *>
10751078
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
10761079
linalg::WinogradInputTransformOp op) {
10771080
Location loc = op.getLoc();
1078-
Value input = op.getInput();
1079-
auto inputType = cast<ShapedType>(input.getType());
1080-
auto inputShape = inputType.getShape();
1081-
int64_t inputH = inputShape[1];
1082-
int64_t inputW = inputShape[2];
1081+
Value output = op.getOutput();
1082+
auto outputType = cast<ShapedType>(output.getType());
1083+
auto outputShape = outputType.getShape();
1084+
1085+
int64_t outputH = outputShape[0];
1086+
int64_t outputW = outputShape[1];
10831087

10841088
// For F(m x 1, r x 1), we only need to do left side transform.
1085-
bool leftTransform = inputH != 1;
1089+
bool leftTransform = outputH != 1;
10861090
// For F(1 x m, 1 x r), we only need to do right side transform.
1087-
bool rightTransform = inputW != 1;
1091+
bool rightTransform = outputW != 1;
10881092
Value transformedInput =
10891093
inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
10901094
op.getR(), leftTransform, rightTransform);

0 commit comments

Comments
 (0)