@@ -678,11 +678,9 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3
678678// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
679679// CHECK: %[[ALLOC:.*]] = memref.alloc
680680// CHECK-SAME: memref<8x?xf32>
681- // CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
682- // CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
681+ // CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
683682// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
684- // CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
685- // CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
683+ // CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
686684// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
687685// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
688686// CHECK: return %[[RET]]
@@ -706,10 +704,9 @@ func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> te
706704// CHECK: %[[ALLOC:.*]] = memref.alloc
707705// CHECK-SAME: memref<?x?xf32>
708706// CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
709- // CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]] ] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
707+ // CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0 ] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
710708// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
711- // CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
712- // CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
709+ // CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
713710// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
714711// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
715712// CHECK: return %[[RET]]
@@ -721,6 +718,35 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
721718
722719// -----
723720
721+ // CHECK: #[[$sum_map:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
722+
723+ // CHECK-LABEL: func @tensor.concat_mixed_dynamic_static(
724+ // CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>, %[[G:.*]]: tensor<8x?xf32>,
725+ // CHECK-SAME: %[[H:.*]]: tensor<8x2xf32>)
726+ // CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
727+ // CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
728+ // CHECK-DAG: %[[H_MEMREF:.*]] = bufferization.to_buffer %[[H]]
729+ // CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x10xf32>
730+ // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
731+ // CHECK: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
732+ // CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
733+ // CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
734+ // CHECK: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
735+ // CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
736+ // CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
737+ // CHECK: %[[OFFSET:.*]] = affine.apply #[[$sum_map]]()[%[[F_DIM]], %[[G_DIM]]]
738+ // CHECK: %[[SUBVIEW3:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, 2] [1, 1]
739+ // CHECK: memref.copy %[[H_MEMREF]], %[[SUBVIEW3]]
740+ // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
741+ // CHECK: return %[[RET]]
742+ // CHECK: }
743+ func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
744+ %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
745+ return %0 : tensor<8x10xf32>
746+ }
747+
748+ // -----
749+
724750// CHECK-LABEL: func @tensor.splat_dynamic(
725751// CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
726752// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
0 commit comments