Skip to content

Commit 1e37f7b

Browse files
authored
Merge pull request #394 from Xilinx/matthias.tiling_backport
Backport various improvements to fusion from upstream
2 parents 2f0e627 + a8317e1 commit 1e37f7b

File tree

24 files changed

+2113
-143
lines changed

24 files changed

+2113
-143
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 135 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,13 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
154154
let hasVerifier = 1;
155155
}
156156

157-
def Linalg_WinogradFilterTransformOp :
158-
Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
157+
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
158+
[AllElementTypesMatch<["filter", "output"]>,
159+
DeclareOpInterfaceMethods<TilingInterface,
160+
["getIterationDomain",
161+
"getLoopIteratorTypes",
162+
"getResultTilePosition",
163+
"getTiledImplementation"]>]> {
159164
let summary = "Winograd filter transform operator";
160165
let description = [{
161166
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -190,11 +195,42 @@ def Linalg_WinogradFilterTransformOp :
190195
`outs` `(` $output `:` type($output) `)`
191196
`->` type($result)
192197
}];
198+
let extraClassDeclaration = [{
199+
ShapedType getFilterOperandType() {
200+
return cast<ShapedType>(getFilter().getType());
201+
}
202+
ShapedType getOutputOperandType() {
203+
return cast<ShapedType>(getOutput().getType());
204+
}
205+
int64_t getFilterOperandRank() {
206+
return getFilterOperandType().getRank();
207+
}
208+
int64_t getOutputOperandRank() {
209+
return getOutputOperandType().getRank();
210+
}
211+
int64_t getFilterFDim() {
212+
return 0;
213+
}
214+
int64_t getFilterHDim() {
215+
return 1;
216+
}
217+
int64_t getFilterWDim() {
218+
return 2;
219+
}
220+
int64_t getFilterCDim() {
221+
return 3;
222+
}
223+
}];
193224
let hasVerifier = 1;
194225
}
195226

196-
def Linalg_WinogradInputTransformOp :
197-
Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
227+
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
228+
[AllElementTypesMatch<["input", "output"]>,
229+
DeclareOpInterfaceMethods<TilingInterface,
230+
["getIterationDomain",
231+
"getLoopIteratorTypes",
232+
"getResultTilePosition",
233+
"getTiledImplementation"]>]> {
198234
let summary = "Winograd input transform operator";
199235
let description = [{
200236
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -229,11 +265,60 @@ def Linalg_WinogradInputTransformOp :
229265
`outs` `(` $output `:` type($output) `)`
230266
`->` type($result)
231267
}];
268+
let extraClassDeclaration = [{
269+
ShapedType getInputOperandType() {
270+
return cast<ShapedType>(getInput().getType());
271+
}
272+
ShapedType getOutputOperandType() {
273+
return cast<ShapedType>(getOutput().getType());
274+
}
275+
int64_t getInputOperandRank() {
276+
return getInputOperandType().getRank();
277+
}
278+
int64_t getOutputOperandRank() {
279+
return getOutputOperandType().getRank();
280+
}
281+
int64_t getInputNDim() {
282+
return 0;
283+
}
284+
int64_t getInputHDim() {
285+
return 1;
286+
}
287+
int64_t getInputWDim() {
288+
return 2;
289+
}
290+
int64_t getInputCDim() {
291+
return 3;
292+
}
293+
int64_t getOutputAlphaHDim() {
294+
return 0;
295+
}
296+
int64_t getOutputAlphaWDim() {
297+
return 1;
298+
}
299+
int64_t getOutputTileHDim() {
300+
return 2;
301+
}
302+
int64_t getOutputTileWDim() {
303+
return 3;
304+
}
305+
int64_t getOutputNDim() {
306+
return 4;
307+
}
308+
int64_t getOutputCDim() {
309+
return 5;
310+
}
311+
}];
232312
let hasVerifier = 1;
233313
}
234314

235-
def Linalg_WinogradOutputTransformOp :
236-
Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
315+
def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
316+
[AllElementTypesMatch<["value", "output"]>,
317+
DeclareOpInterfaceMethods<TilingInterface,
318+
["getIterationDomain",
319+
"getLoopIteratorTypes",
320+
"getResultTilePosition",
321+
"getTiledImplementation"]>]> {
237322
let summary = "Winograd output transform operator";
238323
let description = [{
239324
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -268,6 +353,50 @@ def Linalg_WinogradOutputTransformOp :
268353
`outs` `(` $output `:` type($output) `)`
269354
`->` type($result)
270355
}];
356+
let extraClassDeclaration = [{
357+
ShapedType getValueOperandType() {
358+
return cast<ShapedType>(getValue().getType());
359+
}
360+
ShapedType getOutputOperandType() {
361+
return cast<ShapedType>(getOutput().getType());
362+
}
363+
int64_t getValueOperandRank() {
364+
return getValueOperandType().getRank();
365+
}
366+
int64_t getOutputOperandRank() {
367+
return getOutputOperandType().getRank();
368+
}
369+
int64_t getValueAlphaHDim() {
370+
return 0;
371+
}
372+
int64_t getValueAlphaWDim() {
373+
return 1;
374+
}
375+
int64_t getValueTileHDim() {
376+
return 2;
377+
}
378+
int64_t getValueTileWDim() {
379+
return 3;
380+
}
381+
int64_t getValueNDim() {
382+
return 4;
383+
}
384+
int64_t getValueFDim() {
385+
return 5;
386+
}
387+
int64_t getOutputNDim() {
388+
return 0;
389+
}
390+
int64_t getOutputHDim() {
391+
return 1;
392+
}
393+
int64_t getOutputWDim() {
394+
return 2;
395+
}
396+
int64_t getOutputFDim() {
397+
return 3;
398+
}
399+
}];
271400
let hasVerifier = 1;
272401
}
273402

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,18 +284,23 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
284284
let description = [{
285285
Tiles the operations pointed to by the target handle and fuses their
286286
producers greedily using the options provided as attributes.
287+
288+
If `apply_cleanup` is true then slice canonicalization is applied between
289+
fusion steps.
287290
}];
288291

289292
let arguments =
290293
(ins TransformHandleTypeInterface:$target,
291294
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
292-
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
295+
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
296+
DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup);
293297
let results = (outs TransformHandleTypeInterface:$transformed,
294298
Variadic<TransformHandleTypeInterface>:$loops);
295299

296300
let assemblyFormat = [{
297301
$target ($tile_sizes^)? (`interchange` $tile_interchange^)?
298-
attr-dict `:` functional-type(operands, results)
302+
(`apply_cleanup` `=` $apply_cleanup^)? attr-dict
303+
`:` functional-type(operands, results)
299304
}];
300305
let hasVerifier = 1;
301306
}
@@ -2697,4 +2702,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
26972702
}];
26982703
}
26992704

2705+
def DecomposeWinogradOp : Op<Transform_Dialect,
2706+
"structured.decompose_winograd_op",
2707+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
2708+
TransformOpInterface, TransformEachOpTrait,
2709+
ReportTrackingListenerFailuresOpTrait]> {
2710+
let description = [{
2711+
Decompose winograd operations. It will convert filter, input and output
2712+
transform operations into a combination of scf, tensor, and linalg
2713+
equivalent operations. Before applying this transform operation, users
2714+
need to tile winograd transform operations into supported sizes.
2715+
2716+
#### Return modes:
2717+
2718+
This operation fails if `target` is unsupported. Otherwise, the operation
2719+
succeeds and returns a handle of the sequence that replaces the original
2720+
operations.
2721+
}];
2722+
2723+
let arguments = (ins TransformHandleTypeInterface:$target);
2724+
let results = (outs TransformHandleTypeInterface:$transformed);
2725+
2726+
let assemblyFormat =
2727+
"$target attr-dict `:` functional-type($target, results)";
2728+
2729+
let builders = [
2730+
OpBuilder<(ins "Value":$target)>
2731+
];
2732+
2733+
let extraClassDeclaration = [{
2734+
::mlir::DiagnosedSilenceableFailure applyToOne(
2735+
::mlir::transform::TransformRewriter &rewriter,
2736+
::mlir::Operation *target,
2737+
::mlir::transform::ApplyToEachResultList &results,
2738+
::mlir::transform::TransformState &state);
2739+
}];
2740+
}
2741+
27002742
#endif // LINALG_TRANSFORM_OPS

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,6 +1316,63 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
13161316
linalg::Conv2DNhwcFhwcOp op, int64_t m,
13171317
int64_t r);
13181318

1319+
/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
1320+
/// FHWC. The transformation matrix is 2-dimensional. We need to extract H x W
1321+
/// from FHWC first. We generate 2 levels of loops to iterate on F and C. After
1322+
/// the rewriting, we get
1323+
///
1324+
/// scf.for %f = lo_f to hi_f step 1
1325+
/// scf.for %c = lo_c to hi_c step 1
1326+
/// %extracted = extract filter<h x w> from filter<f x h x w x c>
1327+
/// %ret = linalg.matmul G, %extracted
1328+
/// %ret = linalg.matmul %ret, GT
1329+
/// %inserted = insert %ret into output<h x w x c x f>
1330+
FailureOr<Operation *>
1331+
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
1332+
linalg::WinogradFilterTransformOp op);
1333+
1334+
/// Rewrite linalg.winograd_input_transform. The data layout of the input is
1335+
/// NHWC. The transformation matrix is 2-dimensional. We need to extract H x W
1336+
/// from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH,
1337+
/// and tileW. After the rewriting, we get
1338+
///
1339+
/// scf.for %h = 0 to tileH step 1
1340+
/// scf.for %w = 0 to tileW step 1
1341+
/// scf.for %n = 0 to N step 1
1342+
/// scf.for %c = 0 to C step 1
1343+
/// %extracted = extract %extracted<alphaH x alphaW> from
1344+
/// %input<N x H x W x C>
1345+
/// at [%n, (%h x m), (%w x m), %c]
1346+
/// %ret = linalg.matmul BT, %extracted
1347+
/// %ret = linalg.matmul %ret, B
1348+
/// %inserted = insert %ret<alphaH x alphaW> into
1349+
/// %output<alphaH x alphaW x tileH x tileW x N x C>
1350+
/// at [0, 0, %h, %w, %n, %c]
1351+
FailureOr<Operation *>
1352+
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
1353+
linalg::WinogradInputTransformOp op);
1354+
1355+
/// Rewrite linalg.winograd_output_transform. The data layout of the output is
1356+
/// HWNF. The transformation matrix is 2-dimensional. We need to extract H x W
1357+
/// from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH,
1358+
/// and tileW. After the transformation, we get
1359+
///
1360+
/// scf.for %h = 0 to tileH step 1
1361+
/// scf.for %w = 0 to tileW step 1
1362+
/// scf.for %n = 0 to N step 1
1363+
/// scf.for %f = 0 to F step 1
1364+
/// %extracted = extract %extracted<alphaH x alphaW> from
1365+
/// %input<alphaH x alphaW x tileH x tileW x N x F>
1366+
/// at [0, 0, %h, %w, %n, %f]
1367+
/// %ret = linalg.matmul AT, %extracted
1368+
/// %ret = linalg.matmul %ret, A
1369+
/// %inserted = insert %ret<alphaH x alphaW> into
1370+
/// output<N x H x W x F>
1371+
/// at [%n, (%h x m), (%w x m), %f]
1372+
FailureOr<Operation *>
1373+
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
1374+
linalg::WinogradOutputTransformOp op);
1375+
13191376
//===----------------------------------------------------------------------===//
13201377
// Rewrite patterns wrapping transformations.
13211378
// TODO: every single such pattern should be a close to noop wrapper around a

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,12 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
178178
/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
179179
/// controls whether to omit the partial/boundary tile condition check in
180180
/// cases where we statically know that it is unnecessary.
181-
Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
182-
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
183-
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
184-
ArrayRef<OpFoldResult> subShapeSizes,
185-
bool omitPartialTileCheck);
181+
Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
182+
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
183+
ArrayRef<OpFoldResult> lbs,
184+
ArrayRef<OpFoldResult> ubs,
185+
ArrayRef<OpFoldResult> subShapeSizes,
186+
bool omitPartialTileCheck);
186187

187188
/// Creates extract_slice/subview ops for all `valuesToTile` of the given
188189
/// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop

0 commit comments

Comments
 (0)