diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 616d4a7d0a0ab..a6ae728b20fa4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1971,32 +1971,12 @@ struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
     if (!dim.has_value())
       return failure();
 
-    // Skip static dims. These are folded to constant ops.
-    RankedTensorType resultType = expandShapeOp.getResultType();
-    if (!resultType.isDynamicDim(*dim))
-      return failure();
-
-    // Find reassociation group that contains this result dimension.
-    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
-
-    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
-    // ExpandShapeOp would be ambiguous.)
-    int64_t product = 1;
-    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
-    for (int64_t d : grp) {
-      if (d != dim) {
-        assert(!resultType.isDynamicDim(d) && "expected static dim");
-        product *= resultType.getDimSize(d);
-      }
-    }
-
-    // result dim size = src dim size / (product(other dims in reassoc group))
-    Value srcDimSz =
-        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
-    AffineExpr expr;
-    bindSymbols(dimOp.getContext(), expr);
-    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
-        dimOp, expr.floorDiv(product), srcDimSz);
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+    OpFoldResult outputDim = outputShape[dim.value()];
+    rewriter.replaceOp(dimOp, getValueOrCreateConstantIndexOp(
+                                  rewriter, dimOp.getLoc(), outputDim));
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84..3a0f8e0e073ac 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2278,13 +2278,9 @@ func.func @empty_tensor_canonicalize(%i : index) {
 
 // -----
 
-//       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
 // CHECK-LABEL: func @dim_of_expand_shape(
-//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
-//       CHECK:     %[[c1:.*]] = arith.constant 1 : index
-//       CHECK:     %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
-//       CHECK:     %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
-//       CHECK:     return %[[apply]]
+//  CHECK-SAME:     %{{.*}}: tensor<?x?xf32>, %{{.*}}: index, %[[ARG2:.+]]: index
+//       CHECK:     return %[[ARG2]]
 func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
   %c2 = arith.constant 2 : index
   %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
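
Note: a minimal sketch of the IR-level effect of this change, based on the updated test above. The trailing lines of the test hunk are not shown here, so the expanded result type and the tensor.dim query below are reconstructed for illustration; the %c2 constant suggests result dimension 2 is queried, which maps to the %sz1 output_shape operand, i.e. the third function argument matched by %[[ARG2]].

// Before canonicalization (illustrative input):
func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
  %c2 = arith.constant 2 : index
  // Expand the second source dim into result dims [1, 2, 3, 4, 5]; result dim 2 is dynamic (%sz1).
  %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
      : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
  %d = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32>
  return %d : index
}

// After canonicalization: tensor.dim now folds directly to the matching output_shape
// operand, where the old pattern instead produced an affine.apply (s0 floordiv 40) of
// tensor.dim on the expand_shape source; dead ops are elided here for brevity.
func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
  return %sz1 : index
}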