@@ -2403,6 +2403,53 @@ func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xinde
24032403
24042404// -----
24052405
2406+ // CHECK-LABEL: @reshape_fold_2d
2407+ // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
2408+ func.func @reshape_fold_2d (%arg0 : tensor <?x?xi32 >) -> tensor <?x?xi32 > {
2409+ %c0 = arith.constant 0 : index
2410+ %c1 = arith.constant 1 : index
2411+ %d0 = tensor.dim %arg0 , %c0 : tensor <?x?xi32 >
2412+ %d1 = tensor.dim %arg0 , %c1 : tensor <?x?xi32 >
2413+ %ds = tensor.from_elements %d0 , %d1 : tensor <2 xindex >
2414+ %reshape = tensor.reshape %arg0 (%ds ) : (tensor <?x?xi32 >, tensor <2 xindex >) -> tensor <?x?xi32 >
2415+ // CHECK: return %[[ARG0]]
2416+ return %reshape : tensor <?x?xi32 >
2417+ }
2418+
2419+ // -----
2420+
2421+ // CHECK-LABEL: @reshape_nofold_2d
2422+ // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
2423+ func.func @reshape_nofold_2d (%arg0 : tensor <?x?xi32 >) -> tensor <?x?xi32 > {
2424+ %c0 = arith.constant 0 : index
2425+ %c1 = arith.constant 1 : index
2426+ %d0 = tensor.dim %arg0 , %c0 : tensor <?x?xi32 >
2427+ %d1 = tensor.dim %arg0 , %c1 : tensor <?x?xi32 >
2428+ %ds = tensor.from_elements %d1 , %d0 : tensor <2 xindex >
2429+ // CHECK: tensor.reshape
2430+ %reshape = tensor.reshape %arg0 (%ds ) : (tensor <?x?xi32 >, tensor <2 xindex >) -> tensor <?x?xi32 >
2431+ return %reshape : tensor <?x?xi32 >
2432+ }
2433+
2434+
2435+ // -----
2436+
2437+ // CHECK-LABEL: @reshape_fold_3d_cst
2438+ // CHECK-SAME: %[[ARG0:.+]]: tensor<5x?x?xi32>
2439+ func.func @reshape_fold_3d_cst (%arg0 : tensor <5 x?x?xi32 >) -> tensor <5 x?x?xi32 > {
2440+ %c1 = arith.constant 1 : index
2441+ %c2 = arith.constant 2 : index
2442+ %d0 = arith.constant 5 : index
2443+ %d1 = tensor.dim %arg0 , %c1 : tensor <5 x?x?xi32 >
2444+ %d2 = tensor.dim %arg0 , %c2 : tensor <5 x?x?xi32 >
2445+ %ds = tensor.from_elements %d0 , %d1 , %d2 : tensor <3 xindex >
2446+ %reshape = tensor.reshape %arg0 (%ds ) : (tensor <5 x?x?xi32 >, tensor <3 xindex >) -> tensor <5 x?x?xi32 >
2447+ // CHECK: return %[[ARG0]]
2448+ return %reshape : tensor <5 x?x?xi32 >
2449+ }
2450+
2451+ // -----
2452+
24062453// Test case: This test fails to fold because the index of tensor.dim is out_of_bounds
24072454// CHECK-LABEL: func @dim_out_of_bounds(
24082455// CHECK: %[[IDX:.*]] = index.constant 28
0 commit comments