Skip to content

Commit a02662c

Browse files
committed
[mlir][Tensor] Use output_shape for DimOp->ExpandShapeOp folding
1 parent ba624b7 commit a02662c

File tree

2 files changed

+8
-32
lines changed

2 files changed

+8
-32
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1971,32 +1971,12 @@ struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
19711971
if (!dim.has_value())
19721972
return failure();
19731973

1974-
// Skip static dims. These are folded to constant ops.
1975-
RankedTensorType resultType = expandShapeOp.getResultType();
1976-
if (!resultType.isDynamicDim(*dim))
1977-
return failure();
1978-
1979-
// Find reassociation group that contains this result dimension.
1980-
int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1981-
1982-
// `dim` is the only dynamic dimension in `group`. (Otherwise, the
1983-
// ExpandShapeOp would be ambiguous.)
1984-
int64_t product = 1;
1985-
ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
1986-
for (int64_t d : grp) {
1987-
if (d != dim) {
1988-
assert(!resultType.isDynamicDim(d) && "expected static dim");
1989-
product *= resultType.getDimSize(d);
1990-
}
1991-
}
1992-
1993-
// result dim size = src dim size / (product(other dims in reassoc group))
1994-
Value srcDimSz =
1995-
rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1996-
AffineExpr expr;
1997-
bindSymbols(dimOp.getContext(), expr);
1998-
rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
1999-
dimOp, expr.floorDiv(product), srcDimSz);
1974+
SmallVector<OpFoldResult> outputShape =
1975+
getMixedValues(expandShapeOp.getStaticOutputShape(),
1976+
expandShapeOp.getOutputShape(), rewriter);
1977+
OpFoldResult outputDim = outputShape[dim.value()];
1978+
rewriter.replaceOp(dimOp, getValueOrCreateConstantIndexOp(
1979+
rewriter, dimOp.getLoc(), outputDim));
20001980
return success();
20011981
}
20021982
};

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2278,13 +2278,9 @@ func.func @empty_tensor_canonicalize(%i : index) {
22782278

22792279
// -----
22802280

2281-
// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
22822281
// CHECK-LABEL: func @dim_of_expand_shape(
2283-
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
2284-
// CHECK: %[[c1:.*]] = arith.constant 1 : index
2285-
// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
2286-
// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
2287-
// CHECK: return %[[apply]]
2282+
// CHECK-SAME: %{{.*}}: tensor<?x?xf32>, %{{.*}}: index, %[[ARG2:.+]]: index
2283+
// CHECK: return %[[ARG2]]
22882284
func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
22892285
%c2 = arith.constant 2 : index
22902286
%0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]

0 commit comments

Comments
 (0)