@@ -368,21 +368,18 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
368368
369369// -----
370370
371- // `EmptyTensorElimination` fails to find a valid insertion
372- // point for the new injected `SubsetExtraction`.
373- // CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors
374- func.func @fail_to_eliminate_any_empty_tensors () -> tensor <5 x6 x128 xf32 > {
371+ // CHECK-LABEL: func.func @eliminate_all_empty_tensors
372+ func.func @eliminate_all_empty_tensors () -> tensor <5 x6 x128 xf32 > {
375373 %cst_1 = arith.constant 1.0 : f32
376374 %cst_2 = arith.constant 2.0 : f32
377- // CHECK: memref.alloc
378- // CHECK: memref.alloc
379- // CHECK: memref.alloc
375+ // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
376+ // CHECK-NOT: memref.alloc
380377 %empty_1 = tensor.empty () : tensor <5 x6 x64 xf32 >
381378 %res_1 = linalg.fill ins (%cst_1 : f32 ) outs (%empty_1 : tensor <5 x6 x64 xf32 >) -> tensor <5 x6 x64 xf32 >
382379 %empty_2 = tensor.empty () : tensor <5 x6 x64 xf32 >
383380 %res_2 = linalg.fill ins (%cst_2 : f32 ) outs (%empty_2 : tensor <5 x6 x64 xf32 >) -> tensor <5 x6 x64 xf32 >
384381 %cancatenated_empty = tensor.empty () : tensor <5 x6 x128 xf32 >
385- // CHECK: memref.copy
382+ // CHECK-NOT: memref.copy
386383 %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty [0 , 0 , 0 ][5 , 6 , 64 ][1 , 1 , 1 ]
387384 : tensor <5 x6 x64 xf32 > into tensor <5 x6 x128 xf32 >
388385 %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1 [0 , 0 , 64 ][5 , 6 , 64 ][1 , 1 , 1 ]
@@ -392,20 +389,19 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
392389
393390// -----
394391
395- // CHECK-LABEL: func.func @succeed_to_eliminate_one_empty_tensor
396- func.func @succeed_to_eliminate_one_empty_tensor () -> tensor <5 x6 x128 xf32 > {
392+ // CHECK-LABEL: func.func @eliminate_concatenated_empty_tensors
393+ func.func @eliminate_concatenated_empty_tensors () -> tensor <5 x6 x128 xf32 > {
397394 %cst_1 = arith.constant 1.0 : f32
398395 %cst_2 = arith.constant 2.0 : f32
399396 // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
400- // CHECK: memref.alloc
401397 // CHECK-NOT: memref.alloc
402- %cancatenated_empty = tensor.empty () : tensor <5 x6 x128 xf32 >
398+ %concatenated_empty = tensor.empty () : tensor <5 x6 x128 xf32 >
403399 %empty_1 = tensor.empty () : tensor <5 x6 x64 xf32 >
404400 %res_1 = linalg.fill ins (%cst_1 : f32 ) outs (%empty_1 : tensor <5 x6 x64 xf32 >) -> tensor <5 x6 x64 xf32 >
405401 %empty_2 = tensor.empty () : tensor <5 x6 x64 xf32 >
406402 %res_2 = linalg.fill ins (%cst_2 : f32 ) outs (%empty_2 : tensor <5 x6 x64 xf32 >) -> tensor <5 x6 x64 xf32 >
407- // CHECK: memref.copy
408- %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty [0 , 0 , 0 ][5 , 6 , 64 ][1 , 1 , 1 ]
403+ // CHECK-NOT: memref.copy
404+ %inserted_slice_1 = tensor.insert_slice %res_1 into %concatenated_empty [0 , 0 , 0 ][5 , 6 , 64 ][1 , 1 , 1 ]
409405 : tensor <5 x6 x64 xf32 > into tensor <5 x6 x128 xf32 >
410406 %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1 [0 , 0 , 64 ][5 , 6 , 64 ][1 , 1 , 1 ]
411407 : tensor <5 x6 x64 xf32 > into tensor <5 x6 x128 xf32 >
@@ -420,20 +416,22 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
420416
421417// CHECK-ELIM-LABEL: func.func @multi_use_of_the_same_tensor_empty
422418// CHECK-LABEL: func.func @multi_use_of_the_same_tensor_empty
419+ // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
420+ // CHECK-NOT: memref.alloc
421+ // CHECK-NOT: memref.copy
422+ // CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 0]
423+ // CHECK-ELIM: linalg.fill
424+ // CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 64]
425+ // CHECK-ELIM: linalg.fill
423426func.func @multi_use_of_the_same_tensor_empty () -> tensor <5 x6 x128 xf32 > {
424427 %cst_1 = arith.constant 1.0 : f32
425428 %cst_2 = arith.constant 2.0 : f32
426429 %cancatenated_empty = tensor.empty () : tensor <5 x6 x128 xf32 >
427430 %empty_1 = tensor.empty () : tensor <5 x6 x64 xf32 >
428- // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
429- // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
430- // CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
431431 %res_1 = linalg.fill ins (%cst_1 : f32 ) outs (%empty_1 : tensor <5 x6 x64 xf32 >) -> tensor <5 x6 x64 xf32 >
432432 %res_2 = linalg.fill ins (%cst_2 : f32 ) outs (%empty_1 : tensor <5 x6 x64 xf32 >) -> tensor <5 x6 x64 xf32 >
433- // CHECK: memref.copy
434433 %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty [0 , 0 , 0 ][5 , 6 , 64 ][1 , 1 , 1 ]
435434 : tensor <5 x6 x64 xf32 > into tensor <5 x6 x128 xf32 >
436- // CHECK-NOT: memref.copy
437435 %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1 [0 , 0 , 64 ][5 , 6 , 64 ][1 , 1 , 1 ]
438436 : tensor <5 x6 x64 xf32 > into tensor <5 x6 x128 xf32 >
439437 return %inserted_slice_2 : tensor <5 x6 x128 xf32 >
@@ -476,3 +474,66 @@ func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x
476474 : tensor <5 x6 x64 xf32 > into tensor <5 x6 x128 xf32 >
477475 return %inserted_slice_1 : tensor <5 x6 x128 xf32 >
478476}
477+
478+ // -----
479+
480+ // Test that dependent pure operations are moved before the
481+ // insertion point to enable empty tensor elimination.
482+
483+ // CHECK-LABEL: func.func @move_dependent_arith_op(
484+ // CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>
485+ // CHECK-SAME: %[[ARG1:.*]]: index
486+ // CHECK-NOT: memref.alloc
487+ // CHECK: %[[C5:.*]] = arith.constant 5 : index
488+ // CHECK: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
489+ // CHECK: %[[SV:.*]] = memref.subview %[[ARG0]][%[[OFFSET]]] [5] [1]
490+ // CHECK: linalg.fill {{.*}} outs(%[[SV]]
491+ // CHECK: return %[[ARG0]]
492+ // CHECK-ELIM-LABEL: func.func @move_dependent_arith_op(
493+ // CHECK-ELIM-SAME: %[[ARG0:.*]]: tensor<10xf32>
494+ // CHECK-ELIM-SAME: %[[ARG1:.*]]: index
495+ // CHECK-ELIM: %[[C5:.*]] = arith.constant 5 : index
496+ // CHECK-ELIM: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
497+ // CHECK-ELIM: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[OFFSET]]] [5] [1]
498+ // CHECK-ELIM: %[[FILL:.*]] = linalg.fill {{.*}} outs(%[[SLICE]]
499+ // CHECK-ELIM: tensor.insert_slice %[[FILL]] into %[[ARG0]][%[[OFFSET]]]
500+ func.func @move_dependent_arith_op (
501+ %arg0: tensor <10 xf32 > {bufferization.buffer_layout = affine_map <(d0 ) -> (d0 )>, bufferization.writable = true },
502+ %arg1: index , %f: f32 ) -> tensor <10 xf32 >
503+ {
504+ %0 = tensor.empty () : tensor <5 xf32 >
505+ %1 = linalg.fill ins (%f : f32 ) outs (%0 : tensor <5 xf32 >) -> tensor <5 xf32 >
506+ %c5 = arith.constant 5 : index
507+ %offset = arith.addi %arg1 , %c5 : index
508+ %2 = tensor.insert_slice %1 into %arg0 [%offset ][5 ][1 ]
509+ : tensor <5 xf32 > into tensor <10 xf32 >
510+ return %2 : tensor <10 xf32 >
511+ }
512+
513+ // -----
514+
515+ // Test that side-effecting operations are not moved, preventing empty
516+ // tensor elimination.
517+
518+ // CHECK-LABEL: func.func @side_effecting_op_blocks_movement(
519+ // CHECK: memref.alloc
520+ // CHECK: linalg.fill
521+ // CHECK: memref.load
522+ // CHECK: memref.subview
523+ // CHECK: memref.copy
524+ // CHECK-ELIM-LABEL: func.func @side_effecting_op_blocks_movement(
525+ // CHECK-ELIM: tensor.empty
526+ // CHECK-ELIM: linalg.fill
527+ // CHECK-ELIM: memref.load
528+ // CHECK-ELIM: tensor.insert_slice
529+ func.func @side_effecting_op_blocks_movement (
530+ %arg0: tensor <10 xf32 > {bufferization.buffer_layout = affine_map <(d0 ) -> (d0 )>, bufferization.writable = true },
531+ %mem: memref <index >, %f: f32 ) -> tensor <10 xf32 >
532+ {
533+ %0 = tensor.empty () : tensor <5 xf32 >
534+ %1 = linalg.fill ins (%f : f32 ) outs (%0 : tensor <5 xf32 >) -> tensor <5 xf32 >
535+ %offset = memref.load %mem [] : memref <index >
536+ %2 = tensor.insert_slice %1 into %arg0 [%offset ][5 ][1 ]
537+ : tensor <5 xf32 > into tensor <10 xf32 >
538+ return %2 : tensor <10 xf32 >
539+ }
0 commit comments