
Commit 0d8c636

Add FoldReshapeWithProducerPadOpByExpansion
1 parent 9e26c79 commit 0d8c636


2 files changed: +191 -2 lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 142 additions & 0 deletions
@@ -1101,6 +1101,146 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
+struct FoldReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+    auto isZeroPadding = [](OpFoldResult padValue) -> bool {
+      if (auto attr = dyn_cast<Attribute>(padValue)) {
+        if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+          return intAttr.getInt() == 0;
+      }
+
+      if (auto val = dyn_cast<Value>(padValue)) {
+        if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
+          if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
+            return attr.getInt() == 0;
+        }
+      }
+
+      // When the padding is dynamic and not a constant, we cannot prove that
+      // it is zero, so conservatively return false.
+      return false;
+    };
+
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[idx];
+      OpFoldResult h = high[idx];
+      if (reInd.size() != 1 && (!isZeroPadding(l) || !isZeroPadding(h)))
+        return failure();
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      for (size_t i = 0; i < reInd.size(); ++i) {
+        newLow.push_back(padOp.getMixedLowPad()[idx]);
+        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+      }
+    }
+
+    Location loc = expandOp.getLoc();
+    auto finalType = cast<RankedTensorType>(expandOp.getType());
+    ArrayRef<int64_t> finalShape = finalType.getShape();
+
+    SmallVector<OpFoldResult> expandedShape;
+    for (int64_t dimSize : finalShape) {
+      if (dimSize == ShapedType::kDynamic) {
+        expandedShape.push_back(OpFoldResult{});
+      } else {
+        expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
+      }
+    }
+
+    for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[inDimIdx];
+      OpFoldResult h = high[inDimIdx];
+
+      if (!isZeroPadding(l) || !isZeroPadding(h)) {
+        auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
+        int64_t originalSize = srcType.getDimSize(inDimIdx);
+
+        OpFoldResult originalSizeOFR;
+        if (originalSize == ShapedType::kDynamic) {
+          Value orgSizeVal =
+              rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx);
+          originalSizeOFR = orgSizeVal;
+        } else {
+          originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
+        }
+
+        for (auto outDimIdx : outGroup) {
+          expandedShape[outDimIdx] = originalSizeOFR;
+        }
+      }
+    }
+
+    for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
+      if (dimSize == ShapedType::kDynamic &&
+          !isa<Value>(expandedShape[outDimIdx]) &&
+          !isa<Attribute>(expandedShape[outDimIdx])) {
+        Value actualSize =
+            rewriter.create<tensor::DimOp>(loc, expandOp.getSrc(), outDimIdx);
+        expandedShape[outDimIdx] = actualSize;
+      }
+    }
+
+    SmallVector<int64_t> staticExpandedShape;
+    for (OpFoldResult dim : expandedShape) {
+      if (auto attr = dyn_cast<Attribute>(dim)) {
+        if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+          staticExpandedShape.push_back(intAttr.getInt());
+        } else {
+          staticExpandedShape.push_back(ShapedType::kDynamic);
+        }
+      } else {
+        staticExpandedShape.push_back(ShapedType::kDynamic);
+      }
+    }
+
+    auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc,
+        RankedTensorType::get(staticExpandedShape,
+                              padOp.getSource().getType().getElementType()),
+        padOp.getSource(), reassociations);
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(expandOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -2249,6 +2389,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                 controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
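Because the new pattern is registered through the existing populateFoldReshapeOpsByExpansionPatterns entry point, any pass that already calls it picks up the pad/expand_shape bubbling automatically. For context, here is a minimal sketch (not part of this diff) of such a caller; the helper name runReshapeFusion and the allow-all control function are illustrative only, and applyPatternsAndFoldGreedily is assumed to be the usual greedy rewrite driver.

// Illustrative sketch only: applies the expansion-based reshape fusion
// patterns (now including FoldReshapeWithProducerPadOpByExpansion) to `root`.
// Returning false from the control function for an operand would instead hit
// the "fusion blocked by control function" bail-out in the pattern.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult runReshapeFusion(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Allow every candidate fusion; a real pass would typically filter here.
  mlir::linalg::ControlFusionFn allowAll = [](mlir::OpOperand *) {
    return true;
  };
  mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, allowAll);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}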

mlir/test/Dialect/Linalg/reshape_fusion.mlir

Lines changed: 49 additions & 2 deletions
@@ -247,7 +247,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
-                                                    %arg1 : tensor<?x?xi32>,
+                                                    %arg1 : tensor<?x?xi32>,
     %sz0: index, %sz1: index) ->
   tensor<?x?x4x5xi32>
 {
@@ -515,7 +515,7 @@ func.func @fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
 // -----
 
 func.func @reshape_as_consumer_permutation_with_multiple_results
-    (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
+    (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
      %sz1: index, %sz2: index, %sz3: index, %sz4: index)
     -> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
   %c:2 = linalg.generic {
@@ -893,3 +893,50 @@ func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
 // CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:     ins(%[[EXPANDED]] :
 // CHECK:   return %[[GENERIC]]
+
+// -----
+
+func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+  %padded = tensor.pad %0 low[0, 1, 1] high[0, 1, 1] {
+  ^bb0(%i: index, %j: index, %k: index):
+    tensor.yield %cst : f32
+  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+  %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+  return %expanded : tensor<32x16x258x258xf32>
+}
+// CHECK: func @fold_tensor_pad_with_expand(
+// CHECK-SAME:   %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+// CHECK-DAG:   %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK:   %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] : tensor<32x16x256x256xf32>)
+// CHECK:   %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+// CHECK:   ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK:     tensor.yield %[[CST]] : f32
+// CHECK:   } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+// CHECK:   return %[[PADDED]] : tensor<32x16x258x258xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_expand_dynamic_pad_zero(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+  %padded = tensor.pad %0 low[%c0, %c1, %c1] high[%c0, %c1, %c1] {
+  ^bb0(%i: index, %j: index, %k: index):
+    tensor.yield %cst : f32
+  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+  %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+  return %expanded : tensor<32x16x258x258xf32>
+}
+// CHECK: func @fold_tensor_pad_with_expand_dynamic_pad_zero(
+// CHECK-SAME:   %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+// CHECK:   %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK:   %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]]
+// CHECK:   %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+// CHECK:   ^bb0(
+// CHECK:     tensor.yield %[[CST]] : f32
+// CHECK:   return %[[PADDED]]
