Skip to content

Commit 3b91645

Browse files
committed
fix upon review
1 parent 9ee8e08 commit 3b91645

File tree

1 file changed

+4
-17
lines changed

1 file changed

+4
-17
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,23 +1136,19 @@ struct FoldReshapeWithProducerPadOpByExpansion
11361136
SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
11371137
SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
11381138

1139-
for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
1140-
if (reInd.size() > 1 &&
1141-
(!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)))
1139+
SmallVector<OpFoldResult> newLow, newHigh;
1140+
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1141+
if (reInd.size() > 1 && (!isConstantIntValue(low[idx], 0) ||
1142+
!isConstantIntValue(high[idx], 0)))
11421143
return rewriter.notifyMatchFailure(
11431144
expandOp, "fusion blocked by non-zero padding");
1144-
}
11451145

1146-
SmallVector<OpFoldResult> newLow, newHigh;
1147-
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
11481146
newLow.append(reInd.size(), low[idx]);
11491147
newHigh.append(reInd.size(), high[idx]);
11501148
}
11511149

11521150
Location loc = expandOp.getLoc();
1153-
ArrayRef<int64_t> finalShape = expandOp.getResultType().getShape();
11541151
SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape();
1155-
11561152
for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
11571153
OpFoldResult l = low[inDimIdx];
11581154
OpFoldResult h = high[inDimIdx];
@@ -1165,15 +1161,6 @@ struct FoldReshapeWithProducerPadOpByExpansion
11651161
}
11661162
}
11671163

1168-
for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
1169-
if (dimSize == ShapedType::kDynamic &&
1170-
!isa<Value>(expandedShape[outDimIdx]) &&
1171-
!isa<Attribute>(expandedShape[outDimIdx])) {
1172-
expandedShape[outDimIdx] =
1173-
tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx);
1174-
}
1175-
}
1176-
11771164
SmallVector<int64_t> staticExpandedShape;
11781165
std::tie(staticExpandedShape, std::ignore) =
11791166
decomposeMixedValues(expandedShape);

0 commit comments

Comments
 (0)