@@ -1100,12 +1100,26 @@ class FoldPadWithProducerReshapeOpByExpansion
11001100 ControlFusionFn controlFoldingReshapes;
11011101};
11021102
1103- // / Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
1103+ // / Pattern to move a tensor.expand_shape op with its producer tensor.pad op
11041104// / by bubbling the expand_shape before the pad.
1105- struct FoldReshapeWithProducerPadOpByExpansion
1105+ // /
1106+ // / ```
1107+ // / BEFORE:
1108+ // / %padded = tensor.pad %input low[0, 1, 1] high[0, 1, 1]
1109+ // / tensor<512x256x256xf32> to tensor<512x258x258xf32>
1110+ // / %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]]
1111+ // / tensor<512x258x258xf32> to tensor<32x16x258x258xf32>
1112+ // /
1113+ // / AFTER:
1114+ // / %expanded = tensor.expand_shape %input [[0, 1], [2], [3]]
1115+ // / tensor<512x256x256xf32> to tensor<32x16x256x256xf32>
1116+ // / %padded = tensor.pad %expanded low[0, 0, 1, 1] high[0, 0, 1, 1]
1117+ // / tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
1118+ // / ```
1119+ struct MoveReshapeWithProducerPadOpByExpansion
11061120 : public OpRewritePattern<tensor::ExpandShapeOp> {
11071121
1108- FoldReshapeWithProducerPadOpByExpansion (MLIRContext *context,
1122+ MoveReshapeWithProducerPadOpByExpansion (MLIRContext *context,
11091123 ControlFusionFn foldReshapes,
11101124 PatternBenefit benefit = 1 )
11111125 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
@@ -1181,12 +1195,26 @@ struct FoldReshapeWithProducerPadOpByExpansion
11811195 ControlFusionFn controlFoldingReshapes;
11821196};
11831197
1184- // / Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op
1198+ // / Pattern to move a tensor.collapse_shape op with its producer tensor.pad op
11851199// / by bubbling the collapse_shape before the pad.
1186- struct FoldReshapeWithProducerPadOpByCollapsing
1200+ // /
1201+ // / ```
1202+ // / BEFORE:
1203+ // / %padded = tensor.pad %input low[1, 1, 0] high[1, 1, 0]
1204+ // / tensor<32x16x256xf32> to tensor<34x18x256xf32>
1205+ // / %collapsed = tensor.collapse_shape %padded [[0, 1], [2]]
1206+ // / tensor<34x18x256xf32> to tensor<612x256xf32>
1207+ // /
1208+ // / AFTER:
1209+ // / %collapsed = tensor.collapse_shape %input [[0, 1], [2]]
1210+ // / tensor<32x16x256xf32> to tensor<512x256xf32>
1211+ // / %padded = tensor.pad %collapsed low[1, 0] high[1, 0]
1212+ // / tensor<512x256xf32> to tensor<514x256xf32>
1213+ // / ```
1214+ struct MoveReshapeWithProducerPadOpByCollapsing
11871215 : public OpRewritePattern<tensor::CollapseShapeOp> {
11881216
1189- FoldReshapeWithProducerPadOpByCollapsing (MLIRContext *context,
1217+ MoveReshapeWithProducerPadOpByCollapsing (MLIRContext *context,
11901218 ControlFusionFn foldReshapes,
11911219 PatternBenefit benefit = 1 )
11921220 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
@@ -2394,7 +2422,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
23942422 controlFoldingReshapes);
23952423 patterns.add <FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext (),
23962424 controlFoldingReshapes);
2397- patterns.add <FoldReshapeWithProducerPadOpByExpansion >(patterns.getContext (),
2425+ patterns.add <MoveReshapeWithProducerPadOpByExpansion >(patterns.getContext (),
23982426 controlFoldingReshapes);
23992427 patterns.add <FoldWithProducerReshapeOpByExpansion>(patterns.getContext (),
24002428 controlFoldingReshapes);
@@ -2407,10 +2435,7 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
24072435 controlFoldingReshapes);
24082436 patterns.add <FoldPadWithProducerReshapeOpByCollapsing>(
24092437 patterns.getContext (), controlFoldingReshapes);
2410- patterns.add <FoldReshapeWithProducerPadOpByCollapsing>(
2411- patterns.getContext (), controlFoldingReshapes);
2412-
2413- patterns.add <FoldReshapeWithProducerPadOpByCollapsing>(
2438+ patterns.add <MoveReshapeWithProducerPadOpByCollapsing>(
24142439 patterns.getContext (), controlFoldingReshapes);
24152440 patterns.add <FoldReshapeWithGenericOpByCollapsing>(patterns.getContext (),
24162441 controlFoldingReshapes);
0 commit comments