1+ kernel void kernel_concat_f32_contiguous (
2+ global const char * p_src0 , ulong off_src0 ,
3+ global const char * p_src1 , ulong off_src1 ,
4+ global char * p_dst , ulong off_dst ,
5+ int d_ne00 , int d_ne01 , int d_ne02 , // src0->ne[0..2] for the slice
6+ int d_ne10 , int d_ne11 , int d_ne12 , // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes)
7+ int d_ne0 , int d_ne1 , int d_ne2 , // dst->ne[0..2] for the slice
8+ int dim
9+ ) {
10+ global const float * src0 = (global const float * )((global char * )p_src0 + off_src0 );
11+ global const float * src1 = (global const float * )((global char * )p_src1 + off_src1 );
12+ global float * dst = (global float * )((global char * )p_dst + off_dst );
13+
14+ int i0 = get_global_id (0 ); // Index along dst's 0th dimension
15+ int i1 = get_global_id (1 ); // Index along dst's 1st dimension
16+ int i2 = get_global_id (2 ); // Index along dst's 2nd dimension
17+
18+ if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2 ) {
19+ return ;
20+ }
21+
22+ ulong dst_idx = (ulong )i2 * d_ne0 * d_ne1 + (ulong )i1 * d_ne0 + i0 ;
23+ ulong src_idx ;
24+
25+ if (dim == 0 ) {
26+ if (i0 < d_ne00 ) { // Data from src0
27+ src_idx = (ulong )i2 * d_ne00 * d_ne01 + (ulong )i1 * d_ne00 + i0 ;
28+ dst [dst_idx ] = src0 [src_idx ];
29+ } else { // Data from src1
30+ src_idx = (ulong )i2 * d_ne10 * d_ne11 + (ulong )i1 * d_ne10 + (i0 - d_ne00 );
31+ dst [dst_idx ] = src1 [src_idx ];
32+ }
33+ } else if (dim == 1 ) {
34+ if (i1 < d_ne01 ) { // Data from src0
35+ src_idx = (ulong )i2 * d_ne00 * d_ne01 + (ulong )i1 * d_ne00 + i0 ;
36+ dst [dst_idx ] = src0 [src_idx ];
37+ } else { // Data from src1
38+ src_idx = (ulong )i2 * d_ne10 * d_ne11 + (ulong )(i1 - d_ne01 ) * d_ne10 + i0 ;
39+ dst [dst_idx ] = src1 [src_idx ];
40+ }
41+ } else if (dim == 2 ) {
42+ if (i2 < d_ne02 ) { // Data from src0
43+ src_idx = (ulong )i2 * d_ne00 * d_ne01 + (ulong )i1 * d_ne00 + i0 ;
44+ dst [dst_idx ] = src0 [src_idx ];
45+ } else { // Data from src1
46+
47+ src_idx = (ulong )(i2 - d_ne02 ) * d_ne10 * d_ne11 + (ulong )i1 * d_ne10 + i0 ;
48+ dst [dst_idx ] = src1 [src_idx ];
49+ }
50+ }
51+ }
52+
53+ kernel void kernel_concat_f32_non_contiguous (
54+ global const char * p_src0 , ulong off_src0 ,
55+ global const char * p_src1 , ulong off_src1 ,
56+ global char * p_dst , ulong off_dst ,
57+
58+ long ne00 , long ne01 , long ne02 , long ne03 ,
59+ ulong nb00 , ulong nb01 , ulong nb02 , ulong nb03 ,
60+
61+ ulong nb10 , ulong nb11 , ulong nb12 , ulong nb13 , // Strides for src1
62+
63+ long d_ne0 , long d_ne1 , long d_ne2 , long d_ne3 ,
64+ ulong d_nb0 , ulong d_nb1 , ulong d_nb2 , ulong d_nb3 ,
65+ int dim
66+ ) {
67+ global const char * src0_base = p_src0 + off_src0 ;
68+ global const char * src1_base = p_src1 + off_src1 ;
69+ global char * dst_base = p_dst + off_dst ;
70+
71+ long current_i1 = get_global_id (0 ); // Index for dst_dim_1
72+ long current_i2 = get_global_id (1 ); // Index for dst_dim_2
73+ long current_i3 = get_global_id (2 ); // Index for dst_dim_3
74+
75+ if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3 ) {
76+ return ;
77+ }
78+
79+ global const float * x_val_ptr ;
80+ global float * y_val_ptr ;
81+
82+ for (long current_i0 = 0 ; current_i0 < d_ne0 ; ++ current_i0 ) {
83+ bool use_src0 ;
84+ long s_i0 = current_i0 , s_i1 = current_i1 , s_i2 = current_i2 , s_i3 = current_i3 ;
85+
86+ if (dim == 0 ) {
87+ use_src0 = (current_i0 < ne00 );
88+ if (!use_src0 ) { s_i0 = current_i0 - ne00 ; }
89+ } else if (dim == 1 ) {
90+ use_src0 = (current_i1 < ne01 );
91+ if (!use_src0 ) { s_i1 = current_i1 - ne01 ; }
92+ } else if (dim == 2 ) {
93+ use_src0 = (current_i2 < ne02 );
94+ if (!use_src0 ) { s_i2 = current_i2 - ne02 ; }
95+ } else { // dim == 3
96+ use_src0 = (current_i3 < ne03 );
97+ if (!use_src0 ) { s_i3 = current_i3 - ne03 ; }
98+ }
99+
100+ if (use_src0 ) {
101+ x_val_ptr = (global const float * )(src0_base + (ulong )s_i3 * nb03 + (ulong )s_i2 * nb02 + (ulong )s_i1 * nb01 + (ulong )s_i0 * nb00 );
102+ } else {
103+ x_val_ptr = (global const float * )(src1_base + (ulong )s_i3 * nb13 + (ulong )s_i2 * nb12 + (ulong )s_i1 * nb11 + (ulong )s_i0 * nb10 );
104+ }
105+
106+ y_val_ptr = (global float * )(dst_base + (ulong )current_i3 * d_nb3 + (ulong )current_i2 * d_nb2 + (ulong )current_i1 * d_nb1 + (ulong )current_i0 * d_nb0 );
107+ * y_val_ptr = * x_val_ptr ;
108+ }
109+ }
0 commit comments