@@ -225,6 +225,40 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor<?x?x?x?x?xi32>,
 
 // -----
 
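+// Checks that the elementwise op is collapsed back onto the 1-d source and its
+// result re-expanded with sizes taken from tensor.dim of the expanded operand.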
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1: index) -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, %sz1] : tensor<?xf32> into tensor<?x?xf32>
+  %init = tensor.empty(%sz1, %sz0) : tensor<?x?xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map0],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0 : tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %out = arith.negf %b0 : f32
+      linalg.yield %out : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @fuse_by_collapsing_dynamic_2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[EXPANDED]], %[[C1]]
+// CHECK: %[[OUT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs(%{{.*}} : tensor<?xf32>)
+// CHECK: %[[EXPANDED_1:.+]] = tensor.expand_shape %[[OUT]]
+// CHECK-SAME: output_shape [%[[DIM0]], %[[DIM1]]]
+// CHECK: return %[[EXPANDED_1]]
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
 func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>, %sz0: index) -> tensor<2x5xf32> {
@@ -425,10 +459,11 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK: func @fuse_only_one_reassociation
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<4x?x?x8xf32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C0]] : tensor<?x4x?x8xf32>
+// CHECK-DAG: %[[DIM_2:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C2]] : tensor<?x4x?x8xf32>
 // CHECK-DAG: %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
 // CHECK-DAG: %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
 // CHECK-DAG: %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
@@ -437,10 +472,7 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
 // CHECK-SAME: outs(%[[COLLAPSE_ARG1_1]] :
-// CHECK: %[[DIM:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<4x?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C2]] : tensor<4x?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divui %[[DIM_2]], %[[C8]] : index
-// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[VAL_1]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
+// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[DIM_2]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
 // CHECK: return %[[EXPANDED_3]]
 
 // -----
@@ -475,15 +507,16 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: func @fold_non_consecutive_dims(
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 8] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
-// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM_0]], %[[DIM]])
+// CHECK-DAG: %[[DIM_1:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM_2:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 // CHECK: %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
@@ -502,11 +535,7 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 // CHECK-DAG: %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
 // CHECK-DAG: %[[T7:.+]] = arith.index_cast %[[T6]]
 // CHECK: linalg.yield %[[T7]]
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[GENERIC]], %[[C0]] : tensor<?x?xi32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_2:.+]] = arith.divui %[[DIM_1]], %[[C8]] : index
-// CHECK: %[[VAL_3:.+]] = arith.divui %[[DIM_2]], %[[C4]] : index
-// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
+// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_2]], 8, %[[DIM_1]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
 // CHECK: return %[[EXPANDED_3]]
 
 // -----