diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d589f627d896e..b42e60d5cebd7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1986,90 +1986,6 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
   }
 };
 
-struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
-    if (!expandShapeOp)
-      return failure();
-
-    // Only constant dimension values are supported.
-    std::optional<int64_t> dim = dimOp.getConstantIndex();
-    if (!dim.has_value())
-      return failure();
-
-    // Skip static dims. These are folded to constant ops.
-    RankedTensorType resultType = expandShapeOp.getResultType();
-    if (!resultType.isDynamicDim(*dim))
-      return failure();
-
-    // Find reassociation group that contains this result dimension.
-    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
-
-    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
-    // ExpandShapeOp would be ambiguous.)
-    int64_t product = 1;
-    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
-    for (int64_t d : grp) {
-      if (d != dim) {
-        assert(!resultType.isDynamicDim(d) && "expected static dim");
-        product *= resultType.getDimSize(d);
-      }
-    }
-
-    // result dim size = src dim size / (product(other dims in reassoc group))
-    Value srcDimSz =
-        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
-    AffineExpr expr;
-    bindSymbols(dimOp.getContext(), expr);
-    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
-        dimOp, expr.floorDiv(product), srcDimSz);
-    return success();
-  }
-};
-
-struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
-    if (!collapseShapeOp)
-      return failure();
-
-    // Only constant dimension values are supported.
-    std::optional<int64_t> dim = dimOp.getConstantIndex();
-    if (!dim.has_value() ||
-        dim.value() >= collapseShapeOp.getResultType().getRank())
-      return failure();
-
-    // Skip static dims. These are folded to constant ops.
-    RankedTensorType resultType = collapseShapeOp.getResultType();
-    if (!resultType.isDynamicDim(*dim))
-      return failure();
-
-    // Get reassociation group of the result dimension.
-    ReassociationIndices group =
-        collapseShapeOp.getReassociationIndices()[*dim];
-
-    // result dim size = product(dims in reassoc group)
-    SmallVector<Value> srcDimSizes;
-    SmallVector<AffineExpr> syms;
-    AffineExpr product;
-    for (const auto &it : llvm::enumerate(group)) {
-      srcDimSizes.push_back(rewriter.create<DimOp>(
-          dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
-      syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
-      product = product ? product * syms.back() : syms.back();
-    }
-    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
-                                                       srcDimSizes);
-    return success();
-  }
-};
-
 /// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
 /// matching constant output_shape operands of the expand. This makes the
 /// `tensor.expand_shape` more static and creates a consumer cast that can be
@@ -2158,8 +2074,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
       ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
-      FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
-      FoldDimOfCollapseShape>(context);
+      FoldReshapeWithFromElements<ExpandShapeOp>>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 3256daa8e0b59..a00c798197e5a 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -25,10 +25,8 @@ func.func @drop_one_trip_loops(%arg0 : tensor, %arg1 : f32, %shape: t
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK-LABEL: func @drop_one_trip_loops
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2]]
 // CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]]
@@ -36,11 +34,9 @@ func.func @drop_one_trip_loops(%arg0 : tensor, %arg1 : f32, %shape: t
 // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C0]]
-// CHECK: %[[VAL_1:.*]] = affine.apply #[[$MAP4]]()[%[[DIM]], %[[C1]]]
 // CHECK: %[[DIM_1:.*]] = tensor.dim %{{.*}}, %[[C2]]
-// CHECK: %[[VAL_2:.*]] = affine.apply #[[$MAP4]]()[%[[DIM_1]], %[[C1]]]
 // CHECK: %[[DIM_2:.*]] = tensor.dim %{{.*}}, %[[C2]]
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] output_shape [%[[VAL_1]], 1, %[[VAL_2]], 1, %[[DIM_2]]] : tensor into tensor
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] output_shape [%[[DIM]], 1, %[[DIM_1]], 1, %[[DIM_2]]] : tensor into tensor
 
 // CHECK-SLICES-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 // CHECK-SLICES-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
@@ -79,18 +75,15 @@ func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32,
 }
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> ((((s0 * s1) * s2) * s3) * s4)>
 // CHECK-LABEL: func @drop_one_trip_loops_all_ones
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: tensor.collapse_shape %{{.*}} []
 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]]
 // CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME: iterator_types = ["parallel"]
 // CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x1x?x1x1xf32>
-// CHECK: %[[SZ:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[C1]], %[[DIM]], %[[C1]], %[[C1]]]
-// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1, 2, 3, 4]] output_shape [1, 1, %[[SZ]], 1, 1] : tensor into tensor<1x1x?x1x1xf32>
+// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1, 2, 3, 4]] output_shape [1, 1, %[[DIM]], 1, 1] : tensor into tensor<1x1x?x1x1xf32>
 
 // -----
 
@@ -406,7 +399,6 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32>
 }
 // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * s2)>
 // CHECK: func @unit_dim_for_reduction
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x?xf32>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -422,8 +414,7 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32>
 // CHECK-SAME: ins(%[[RESHAPE]] : tensor)
 // CHECK-SAME: outs(%[[FILL]] : tensor)
 // CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x1x?xf32>
-// CHECK: %[[VAL_3:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[DIM_0]], %[[C1]]]
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_3]]] : tensor into tensor<1x?xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1]] output_shape [1, %[[DIM_0]]] : tensor into tensor<1x?xf32>
 // CHECK: return %[[EXPANDED]] : tensor<1x?xf32>
 
 // -----
 
@@ -482,10 +473,8 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor) -> tensor
 // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK: func @unit_dim_for_reduction_inner
 // CHECK-SAME: %[[ARG0:.+]]: tensor
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
@@ -499,8 +488,7 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor) -> tensor
 // CHECK-SAME: ins(%[[RESHAPE]] : tensor)
 // CHECK-SAME: outs(%[[FILL]] : tensor)
 // CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
-// CHECK: %[[VAL_3:.+]] = affine.apply #[[$MAP3]]()[%[[DIM_0]], %[[C1]]]
-// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [%[[VAL_3]], 1] : tensor into tensor
+// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [%[[DIM_0]], 1] : tensor into tensor
 // CHECK: return %[[RESULT_RESHAPE]]
 
 // -----
 
@@ -1017,7 +1005,6 @@ func.func @drop_unit_pad_dynamic_dims(%arg0: tensor<1x?xf32>) -> tensor<1x?xf32>
   return %0 : tensor<1x?xf32>
 }
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 11)>
 // CHECK-LABEL: func @drop_unit_pad_dynamic_dims
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -1027,8 +1014,7 @@ func.func @drop_unit_pad_dynamic_dims(%arg0: tensor<1x?xf32>) -> tensor<1x?xf32>
 // CHECK: %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[5] high[6]
 // CHECK: } : tensor to tensor
 // CHECK: %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?xf32>
-// CHECK: %[[VAL_0:.+]] = affine.apply #[[$MAP]]()[%[[C1]], %[[DIM]]]
-// CHECK: %[[VAL_1:.+]] = affine.apply #[[$MAP1]]()[%[[VAL_0]]]
+// CHECK: %[[VAL_1:.+]] = affine.apply #[[$MAP1]]()[%[[DIM]]]
 // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_1]]] : tensor into tensor<1x?xf32>
 
 // CHECK-SLICES: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 11)>
@@ -1090,7 +1076,6 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
 
 // -----
 
-// CHECK: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()>
@@ -1098,12 +1083,10 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
 // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x?x1xf32>,
 // CHECK-SAME: %[[ARG1:.*]]: index) -> tensor {
 // CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK: %[[VAL_2:.*]] = arith.constant dense<1.000000e+00> : tensor
 // CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] : tensor<1x?x?x1xf32> into tensor
 // CHECK: %[[VAL_4:.*]] = tensor.empty(%[[ARG1]]) : tensor
-// CHECK: %[[VAL_5:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[ARG1]], %[[VAL_1]]]
-// CHECK: %[[VAL_6:.*]] = tensor.empty(%[[VAL_5]]) : tensor
+// CHECK: %[[VAL_6:.*]] = tensor.empty(%[[ARG1]]) : tensor
 // CHECK: %[[VAL_7:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[VAL_3]], %[[VAL_2]], %[[VAL_4]] : tensor, tensor, tensor) outs(%[[VAL_6]] : tensor) {
 // CHECK: ^bb0(%[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
 // CHECK: %[[VAL_12:.*]] = arith.mulf %[[VAL_8]], %[[VAL_9]] : f32
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index c68a6362f52c5..43bddb075e649 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -76,13 +76,13 @@ func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?x
   // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
   // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
   // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-  // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0
   // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
   // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
   // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
   // CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat
   // CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor)
-  // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
   // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
   // CHECK-NEXT: return %[[RES]]
   %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
@@ -134,7 +134,7 @@ func.func @matmul_to_matvec_tensor(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor
   // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
   // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec
   // CHECK-SAME: ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor)
-  // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]]
+  // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
   // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1]
   // CHECK-NEXT: return %[[RES]]
   %0 = linalg.matmul ins(%arg0, %arg1: tensor, tensor) outs(%arg2: tensor) -> tensor
@@ -171,12 +171,12 @@ func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32>
   // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
   // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor
   // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-  // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0
   // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
   // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
   // CHECK-NEXT: %[[RESULT:.*]] = linalg.vecmat
   // CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor)
-  // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
   // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
   // CHECK-NEXT: return %[[RES]]
   %0 = linalg.matmul ins(%arg0, %arg1: tensor<1x?xf32>, tensor) outs(%arg2: tensor<1x?xf32>) -> tensor<1x?xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index fd96328c6033d..85bf6fba52aa4 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1105,15 +1105,13 @@ func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor) -> tensor
   %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%div, 384] : tensor into tensor
   return %expanded : tensor
 }
-// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 64)>
 // CHECK-LABEL: @compose_expand_of_collapse_last_two_dims
 // CHECK-SAME: %[[ARG0:.+]]: tensor
-// CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
 // CHECK: %[[CONSTANT384:.+]] = arith.constant 384 : index
+// CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
 // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]] : tensor into tensor
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[CONSTANT0]] : tensor
-// CHECK: %[[AFFAPPLY:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK: %[[DIVUI:.+]] = arith.divui %[[AFFAPPLY]], %[[CONSTANT384]] : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[COLLAPSE]], %[[CONSTANT0]] : tensor
+// CHECK: %[[DIVUI:.+]] = arith.divui %[[DIM]], %[[CONSTANT384]] : index
 // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1]] output_shape [%[[DIVUI]], 384] : tensor into tensor
 // CHECK: return %[[RESULT]]
@@ -2137,13 +2135,12 @@ func.func @empty_tensor_canonicalize(%i : index) {
 
 // -----
 
-// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
 // CHECK-LABEL: func @dim_of_expand_shape(
 // CHECK-SAME: %[[t:.*]]: tensor
-// CHECK: %[[c1:.*]] = arith.constant 1 : index
-// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor
-// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
-// CHECK: return %[[apply]]
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[expanded:.*]] = tensor.expand_shape %[[t]] {{\[\[}}0], [1, 2, 3, 4, 5]] output_shape [%arg1, 1, %arg2, 5, 1, 8] : tensor into tensor
+// CHECK: %[[dim:.*]] = tensor.dim %[[expanded]], %[[c2]] : tensor
+// CHECK: return %[[dim]]
 func.func @dim_of_expand_shape(%t: tensor, %sz0: index, %sz1: index) -> index {
   %c2 = arith.constant 2 : index
   %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
@@ -2154,17 +2151,12 @@ func.func @dim_of_expand_shape(%t: tensor, %sz0: index, %sz1: index) ->
 
 // -----
 
-// CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>
 // CHECK-LABEL: func @dim_of_collapse_shape(
 // CHECK-SAME: %[[t:.*]]: tensor
 // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
-// CHECK-DAG: %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
-// CHECK-DAG: %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
-// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]]
-// CHECK: return %[[apply]]
+// CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[t]] {{\[\[}}0], [1, 2, 3, 4]] : tensor into tensor
+// CHECK-DAG: %[[dim:.*]] = tensor.dim %[[collapsed]], %[[c1]]
+// CHECK: return %[[dim]]
 func.func @dim_of_collapse_shape(%t: tensor) -> index {
   %c1 = arith.constant 1 : index
   %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]