@@ -225,6 +225,38 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor<?x?x?x?x?xi32>,
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1: index) -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, %sz1] : tensor<?xf32> into tensor<?x?xf32>
+  %init = tensor.empty(%sz1, %sz0) : tensor<?x?xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map0],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0 : tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %out = arith.negf %b0 : f32
+      linalg.yield %out : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @fuse_by_collapsing_dynamic_2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[EXPANDED]], %[[C1]]
+// CHECK: %[[OUT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs(%{{.*}} : tensor<?xf32>)
+// CHECK: %[[EXPANDED_1:.+]] = tensor.expand_shape %[[OUT]]
+// CHECK-SAME: output_shape [%[[DIM0]], %[[DIM1]]]
+// CHECK: return %[[EXPANDED_1]]
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
 func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>, %sz0: index) -> tensor<2x5xf32> {
@@ -425,10 +457,11 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK: func @fuse_only_one_reassociation
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<4x?x?x8xf32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C0]] : tensor<?x4x?x8xf32>
+// CHECK-DAG: %[[DIM_2:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C2]] : tensor<?x4x?x8xf32>
 // CHECK-DAG: %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
 // CHECK-DAG: %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
 // CHECK-DAG: %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
@@ -437,10 +470,7 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
 // CHECK-SAME: outs(%[[COLLAPSE_ARG1_1]] :
-// CHECK: %[[DIM:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<4x?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C2]] : tensor<4x?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
-// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[VAL_1]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
+// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[DIM_2]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
 // CHECK: return %[[EXPANDED_3]]
 
 // -----
@@ -475,15 +505,16 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: func @fold_non_consecutive_dims(
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 8] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
-// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM_0]], %[[DIM]])
+// CHECK-DAG: %[[DIM_1:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM_2:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 // CHECK: %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
@@ -502,11 +533,7 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 // CHECK-DAG: %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
 // CHECK-DAG: %[[T7:.+]] = arith.index_cast %[[T6]]
 // CHECK: linalg.yield %[[T7]]
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[GENERIC]], %[[C0]] : tensor<?x?xi32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C8]] : index
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C4]] : index
-// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
+// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_2]], 8, %[[DIM_1]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
 // CHECK: return %[[EXPANDED_3]]
 
 // -----
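
A minimal standalone sketch (not part of the diff; the function name is hypothetical) of the behavior the updated CHECK lines above rely on: a tensor.dim taken on the result of a tensor.expand_shape resolves directly to the corresponding output_shape operand, so the expanded sizes no longer need to be recomputed with arith.divsi after the fused linalg.generic:

// Hedged sketch, assuming the dim-of-expand_shape folding these tests
// exercise; not copied from the patch itself.
func.func @dim_of_expand_shape_sketch(%arg0: tensor<?xf32>, %sz0: index, %sz1: index) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %expanded = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, %sz1]
      : tensor<?xf32> into tensor<?x?xf32>
  // With the fold in place, these dims are known to equal %sz0 and %sz1,
  // so downstream code can reuse them instead of dividing collapsed sizes.
  %d0 = tensor.dim %expanded, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %expanded, %c1 : tensor<?x?xf32>
  return %d0, %d1 : index, index
}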