Skip to content

Commit 9cbd032

Browse files
committed
fix comments
1 parent 0faf084 commit 9cbd032

File tree

1 file changed

+36
-11
lines changed

1 file changed

+36
-11
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,12 +1100,26 @@ class FoldPadWithProducerReshapeOpByExpansion
11001100
ControlFusionFn controlFoldingReshapes;
11011101
};
11021102

1103-
/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
1103+
/// Pattern to move a tensor.expand_shape op with its producer tensor.pad op
11041104
/// by bubbling the expand_shape before the pad.
1105-
struct FoldReshapeWithProducerPadOpByExpansion
1105+
///
1106+
/// ```
1107+
/// BEFORE:
1108+
/// %padded = tensor.pad %input low[0, 1, 1] high[0, 1, 1]
1109+
/// tensor<512x256x256xf32> to tensor<512x258x258xf32>
1110+
/// %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]]
1111+
/// tensor<512x258x258xf32> to tensor<32x16x258x258xf32>
1112+
///
1113+
/// AFTER:
1114+
/// %expanded = tensor.expand_shape %input [[0, 1], [2], [3]]
1115+
/// tensor<512x256x256xf32> to tensor<32x16x256x256xf32>
1116+
/// %padded = tensor.pad %expanded low[0, 0, 1, 1] high[0, 0, 1, 1]
1117+
/// tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
1118+
/// ```
1119+
struct MoveReshapeWithProducerPadOpByExpansion
11061120
: public OpRewritePattern<tensor::ExpandShapeOp> {
11071121

1108-
FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
1122+
MoveReshapeWithProducerPadOpByExpansion(MLIRContext *context,
11091123
ControlFusionFn foldReshapes,
11101124
PatternBenefit benefit = 1)
11111125
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
@@ -1181,12 +1195,26 @@ struct FoldReshapeWithProducerPadOpByExpansion
11811195
ControlFusionFn controlFoldingReshapes;
11821196
};
11831197

1184-
/// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op
1198+
/// Pattern to move a tensor.collapse_shape op with its producer tensor.pad op
11851199
/// by bubbling the collapse_shape before the pad.
1186-
struct FoldReshapeWithProducerPadOpByCollapsing
1200+
///
1201+
/// ```
1202+
/// BEFORE:
1203+
/// %padded = tensor.pad %input low[1, 1, 0] high[1, 1, 0]
1204+
/// tensor<32x16x256xf32> to tensor<34x18x256xf32>
1205+
/// %collapsed = tensor.collapse_shape %padded [[0, 1], [2]]
1206+
/// tensor<34x18x256xf32> to tensor<612x256xf32>
1207+
///
1208+
/// AFTER:
1209+
/// %collapsed = tensor.collapse_shape %input [[0, 1], [2]]
1210+
/// tensor<32x16x256xf32> to tensor<512x256xf32>
1211+
/// %padded = tensor.pad %collapsed low[1, 0] high[1, 0]
1212+
/// tensor<512x256xf32> to tensor<514x256xf32>
1213+
/// ```
1214+
struct MoveReshapeWithProducerPadOpByCollapsing
11871215
: public OpRewritePattern<tensor::CollapseShapeOp> {
11881216

1189-
FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
1217+
MoveReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
11901218
ControlFusionFn foldReshapes,
11911219
PatternBenefit benefit = 1)
11921220
: OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
@@ -2394,7 +2422,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
23942422
controlFoldingReshapes);
23952423
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
23962424
controlFoldingReshapes);
2397-
patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
2425+
patterns.add<MoveReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
23982426
controlFoldingReshapes);
23992427
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
24002428
controlFoldingReshapes);
@@ -2407,10 +2435,7 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
24072435
controlFoldingReshapes);
24082436
patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
24092437
patterns.getContext(), controlFoldingReshapes);
2410-
patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
2411-
patterns.getContext(), controlFoldingReshapes);
2412-
2413-
patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
2438+
patterns.add<MoveReshapeWithProducerPadOpByCollapsing>(
24142439
patterns.getContext(), controlFoldingReshapes);
24152440
patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
24162441
controlFoldingReshapes);

0 commit comments

Comments
 (0)