
Commit 57ec657

add collapse_shape
1 parent cf1e560 commit 57ec657

File tree: 2 files changed (+204, -25 lines)
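In short, the commit adds a FoldReshapeWithProducerPadOpByCollapsing pattern that bubbles a tensor.collapse_shape above its producer tensor.pad whenever the padding value is a constant and every dimension group that actually collapses carries zero padding. A minimal before/after sketch (illustrative only; names like %src and %cst are placeholders, shapes are borrowed from the test added below):

  // Before: the pad feeds the collapse_shape.
  %padded = tensor.pad %src low[0, 0, 1, 1] high[0, 0, 1, 1] {
  ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
    tensor.yield %cst : f32
  } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
  %collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
      : tensor<32x16x258x258xf32> into tensor<512x258x258xf32>

  // After: collapse the un-padded source, then pad the collapsed result.
  %collapsed = tensor.collapse_shape %src [[0, 1], [2], [3]]
      : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
  %padded = tensor.pad %collapsed low[0, 1, 1] high[0, 1, 1] {
  ^bb0(%i0: index, %i1: index, %i2: index):
    tensor.yield %cst : f32
  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>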

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 152 additions & 24 deletions
@@ -26,6 +26,8 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/LogicalResult.h"
 #include <optional>
 #include <utility>

@@ -1100,6 +1102,20 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+bool isZero(OpFoldResult value) {
+  if (auto attr = dyn_cast<Attribute>(value)) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+      return intAttr.getInt() == 0;
+  }
+  if (auto val = dyn_cast<Value>(value)) {
+    if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
+      if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
+        return attr.getInt() == 0;
+    }
+  }
+  return false;
+}
+
 /// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
 /// by bubbling the expand_shape before the pad.
 struct FoldReshapeWithProducerPadOpByExpansion
@@ -1125,41 +1141,29 @@ struct FoldReshapeWithProducerPadOpByExpansion
                                          "fusion blocked by control function");
     }
 
+    Value constantPaddingValue = padOp.getConstantPaddingValue();
+    if (!constantPaddingValue) {
+      return rewriter.notifyMatchFailure(
+          expandOp, "cannot fold with non-constant padding value");
+    }
+
     SmallVector<ReassociationIndices> reassociations =
         expandOp.getReassociationIndices();
     SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
     SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
 
-    auto isZeroPadding = [](OpFoldResult padValue) -> bool {
-      if (auto attr = dyn_cast<Attribute>(padValue)) {
-        if (auto intAttr = dyn_cast<IntegerAttr>(attr))
-          return intAttr.getInt() == 0;
-      }
-
-      if (auto val = dyn_cast<Value>(padValue)) {
-        if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
-          if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
-            return attr.getInt() == 0;
-        }
-      }
-
-      // when padding is dynamic and not constant, we don't know if it's zero or
-      // not. so we return false here.
-      return false;
-    };
-
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
       OpFoldResult l = low[idx];
       OpFoldResult h = high[idx];
-      if (reInd.size() != 1 && (!isZeroPadding(l) || !isZeroPadding(h)))
+      if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
         return failure();
     }
 
     SmallVector<OpFoldResult> newLow, newHigh;
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
       for (size_t i = 0; i < reInd.size(); ++i) {
-        newLow.push_back(padOp.getMixedLowPad()[idx]);
-        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+        newLow.push_back(low[idx]);
+        newHigh.push_back(high[idx]);
       }
     }

@@ -1176,11 +1180,11 @@ struct FoldReshapeWithProducerPadOpByExpansion
       }
     }
 
-    for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+    for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
       OpFoldResult l = low[inDimIdx];
       OpFoldResult h = high[inDimIdx];
 
-      if (!isZeroPadding(l) || !isZeroPadding(h)) {
+      if (!isZero(l) || !isZero(h)) {
         auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
         int64_t originalSize = srcType.getDimSize(inDimIdx);

@@ -1193,7 +1197,7 @@ struct FoldReshapeWithProducerPadOpByExpansion
           originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
         }
 
-        for (auto outDimIdx : outGroup) {
+        for (auto outDimIdx : reInd) {
           expandedShape[outDimIdx] = originalSizeOFR;
         }
       }
@@ -1240,6 +1244,125 @@ struct FoldReshapeWithProducerPadOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op
+/// by bubbling the collapse_shape before the pad.
+struct FoldReshapeWithProducerPadOpByCollapsing
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+
+  FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = collapseOp.getSrc().getDefiningOp<tensor::PadOp>();
+
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&collapseOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(collapseOp,
+                                         "fusion blocked by control function");
+    }
+
+    Value constantPaddingValue = padOp.getConstantPaddingValue();
+    if (!constantPaddingValue) {
+      return rewriter.notifyMatchFailure(
+          collapseOp, "cannot fold with non-constant padding value");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        collapseOp.getReassociationIndices();
+    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() > 1) {
+        for (auto dimIdx : reInd) {
+          if (!isZero(low[dimIdx]) || !isZero(high[dimIdx])) {
+            return failure();
+          }
+        }
+      }
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      newLow.push_back(low[reInd[0]]);
+      newHigh.push_back(high[reInd[0]]);
+    }
+
+    Location loc = collapseOp.getLoc();
+    auto resultType = collapseOp.getResultType();
+
+    auto finalType = cast<RankedTensorType>(collapseOp.getType());
+    ArrayRef<int64_t> finalShape = finalType.getShape();
+
+    SmallVector<OpFoldResult> collapsedShape;
+    for (int64_t dimSize : finalShape) {
+      if (dimSize == ShapedType::kDynamic) {
+        collapsedShape.push_back(OpFoldResult{});
+      } else {
+        collapsedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
+      }
+    }
+
+    for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[reInd[0]];
+      OpFoldResult h = high[reInd[0]];
+
+      if (!isZero(l) || !isZero(h)) {
+        auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
+        int64_t originalSize = srcType.getDimSize(reInd[0]);
+
+        OpFoldResult originalSizeOFR;
+        if (originalSize == ShapedType::kDynamic) {
+          Value orgSizeVal =
+              rewriter.create<tensor::DimOp>(loc, padOp.getSource(), reInd[0]);
+          originalSizeOFR = orgSizeVal;
+        } else {
+          originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
+        }
+        collapsedShape[inDimIdx] = originalSizeOFR;
+      }
+    }
+
+    SmallVector<int64_t> staticCollapsedShape;
+    for (OpFoldResult dim : collapsedShape) {
+      if (auto attr = dyn_cast<Attribute>(dim)) {
+        if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+          staticCollapsedShape.push_back(intAttr.getInt());
+        } else {
+          staticCollapsedShape.push_back(ShapedType::kDynamic);
+        }
+      } else {
+        staticCollapsedShape.push_back(ShapedType::kDynamic);
+      }
+    }
+
+    auto newCollapseType = RankedTensorType::get(
+        staticCollapsedShape, padOp.getSource().getType().getElementType());
+    auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
+        loc, newCollapseType, padOp.getSource(), reassociations);
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, resultType, newCollapseOp.getResult(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(collapseOp, newPadOp.getResult());
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -2388,6 +2511,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
                                                      controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
   patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
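For orientation, a minimal sketch of how these collapsing patterns might be driven once registered. The populate function is the one extended above; the helper name, the always-true control function, and the greedy-driver entry point are illustrative assumptions rather than part of this commit:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Collect the reshape-by-collapsing fusion patterns (which now include
// FoldReshapeWithProducerPadOpByCollapsing) and apply them greedily to `op`.
static LogicalResult runCollapsingReshapeFusion(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  // Illustrative control function: allow every candidate fusion.
  linalg::ControlFusionFn control = [](OpOperand *) { return true; };
  linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, control);
  // The driver entry point name may vary across MLIR revisions
  // (e.g. applyPatternsGreedily in newer trees).
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}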

mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir

Lines changed: 52 additions & 1 deletion
@@ -232,7 +232,7 @@ func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1
   %1 = linalg.generic {
     indexing_maps = [#map0, #map0],
     iterator_types = ["parallel", "parallel"]}
-    ins(%0 : tensor<?x?xf32>) 
+    ins(%0 : tensor<?x?xf32>)
     outs(%init : tensor<?x?xf32>) {
   ^bb0(%b0 : f32, %b1 : f32):
     %out = arith.negf %b0 : f32
@@ -858,3 +858,54 @@ func.func @partial_fuse_by_collapsing(%arg0: tensor<4x?x32x128x192xf16>, %arg1:
 // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
 // CHECK-SAME: tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
 // CHECK: return %[[COLLAPSED]] : tensor<512x192x?xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_collapse(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
+  %padded = tensor.pad %0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
+  ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
+    tensor.yield %cst : f32
+  } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+  %collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
+      : tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
+  return %collapsed : tensor<512x258x258xf32>
+}
+// CHECK: func @fold_tensor_pad_with_collapse(
+// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
+// CHECK-SAME: : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
+// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index):
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+// CHECK: return %[[PADDED]] : tensor<512x258x258xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_collapse_dynamic_pad_zero(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
+  %padded = tensor.pad %0 low[%c0, %c0, %c1, %c1] high[%c0, %c0, %c1, %c1] {
+  ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
+    tensor.yield %cst : f32
+  } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+  %collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
+      : tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
+  return %collapsed : tensor<512x258x258xf32>
+}
+// CHECK: func @fold_tensor_pad_with_collapse_dynamic_pad_zero(
+// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
+// CHECK-SAME: : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
+// CHECK: ^bb0(
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: return %[[PADDED]]
