1313#include " concat.hpp"
1414#include " common.hpp"
1515
16- static void concat_f32_dim0 (const float *x, const float *y, float *dst,
16+ static inline size_t elem_size (ggml_type t) {
17+ return ggml_type_size (t) / ggml_blck_size (t);
18+ }
19+ template <typename T>
20+ static void concat_T_dim0 (const T *x, const T *y, T *dst,
1721 const int ne0, const int ne00,
1822 const sycl::nd_item<3 > &item_ct1) {
1923 int nidx = item_ct1.get_local_id (2 ) +
@@ -36,7 +40,8 @@ static void concat_f32_dim0(const float *x, const float *y, float *dst,
3640 }
3741}
3842
39- static void concat_f32_dim1 (const float *x, const float *y, float *dst,
43+ template <typename T>
44+ static void concat_T_dim1 (const T *x, const T *y, T *dst,
4045 const int ne0, const int ne01,
4146 const sycl::nd_item<3 > &item_ct1) {
4247 int nidx = item_ct1.get_local_id (2 ) +
@@ -59,7 +64,8 @@ static void concat_f32_dim1(const float *x, const float *y, float *dst,
5964 }
6065}
6166
62- static void concat_f32_dim2 (const float *x, const float *y, float *dst,
67+ template <typename T>
68+ static void concat_T_dim2 (const T *x, const T *y, T *dst,
6369 const int ne0, const int ne02,
6470 const sycl::nd_item<3 > &item_ct1) {
6571 int nidx = item_ct1.get_local_id (2 ) +
@@ -82,45 +88,38 @@ static void concat_f32_dim2(const float *x, const float *y, float *dst,
8288 }
8389}
8490
85- static void concat_f32_sycl (const float *x, const float *y, float *dst,
91+ template <typename T>
92+ static void concat_T_sycl (const T *x, const T *y, T *dst,
8693 int ne00, int ne01, int ne02, int ne0, int ne1,
8794 int ne2, int dim, queue_ptr stream) {
8895 int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1 ) / SYCL_CONCAT_BLOCK_SIZE;
8996 sycl::range<3 > gridDim (ne2, ne1, num_blocks);
9097 switch (dim) {
9198 case 0 :
92- stream->parallel_for (
93- sycl::nd_range<3 >(gridDim *
94- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
95- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
96- [=](sycl::nd_item<3 > item_ct1) {
97- concat_f32_dim0 (x, y, dst, ne0, ne00, item_ct1);
98- });
99- break ;
99+ sycl_parallel_for (stream,
100+ sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
101+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
102+ [=](sycl::nd_item<3 > item_ct1) { concat_T_dim0<T>(x, y, dst, ne0, ne00, item_ct1); });
103+ break ;
100104 case 1 :
101- stream->parallel_for (
102- sycl::nd_range<3 >(gridDim *
103- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
104- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
105- [=](sycl::nd_item<3 > item_ct1) {
106- concat_f32_dim1 (x, y, dst, ne0, ne01, item_ct1);
107- });
108- break ;
105+ sycl_parallel_for (stream,
106+ sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
107+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
108+ [=](sycl::nd_item<3 > item_ct1) { concat_T_dim1<T>(x, y, dst, ne0, ne01, item_ct1); });
109+ break ;
109110 // dim >=2 will be dispatched to the default path
110111 default :
111- stream->parallel_for (
112- sycl::nd_range<3 >(gridDim *
113- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
114- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
115- [=](sycl::nd_item<3 > item_ct1) {
116- concat_f32_dim2 (x, y, dst, ne0, ne02, item_ct1);
117- });
118- break ;
112+ sycl_parallel_for (stream,
113+ sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
114+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
115+ [=](sycl::nd_item<3 > item_ct1) { concat_T_dim2<T>(x, y, dst, ne0, ne02, item_ct1); });
116+ break ;
119117 }
120118}
121119
122120// non-contiguous kernel (slow)
123- static void concat_f32_sycl_non_cont (
121+ template <typename T>
122+ static void concat_T_sycl_non_cont (
124123 queue_ptr stream, const char *src0, const char *src1, char *dst,
125124 int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, uint64_t nb00,
126125 uint64_t nb01, uint64_t nb02, uint64_t nb03, int64_t /* ne10*/ ,
@@ -129,32 +128,33 @@ static void concat_f32_sycl_non_cont(
129128 int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
130129 uint64_t nb3, int32_t dim) {
131130 sycl::range<3 > gridDim (ne3, ne2, ne1);
132- stream-> parallel_for ( sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
131+ sycl_parallel_for (stream, sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
133132 int64_t i3 = item_ct1.get_group (0 );
134133 int64_t i2 = item_ct1.get_group (1 );
135134 int64_t i1 = item_ct1.get_group (2 );
136135
137136 int64_t o[4 ] = { 0 , 0 , 0 , 0 };
138137 o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
139138
140- const float * x;
139+ const T * x;
141140
142141 for (int i0 = item_ct1.get_local_id (2 ); i0 < ne0; i0 += item_ct1.get_local_range (2 )) {
143142 if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
144- x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
143+ x = (const T *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
145144 } else {
146- x = (const float *) (src1 + (i3 - o[3 ]) * nb13 + (i2 - o[2 ]) * nb12 + (i1 - o[1 ]) * nb11 +
145+ x = (const T *) (src1 + (i3 - o[3 ]) * nb13 + (i2 - o[2 ]) * nb12 + (i1 - o[1 ]) * nb11 +
147146 (i0 - o[0 ]) * nb10);
148147 }
149148
150- float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
149+ T *y = (T *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
151150
152151 *y = *x;
153152 }
154153 });
155154}
156155
157- void ggml_sycl_op_concat (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
156+ template <typename T>
157+ void concat_impl_sycl (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
158158 scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 2 );
159159 const ggml_tensor * src0 = dst->src [0 ];
160160 const ggml_tensor * src1 = dst->src [1 ];
@@ -163,29 +163,55 @@ void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
163163 const int32_t dim = ((int32_t *) dst->op_params )[0 ];
164164
165165 if (ggml_is_contiguous (src0) && ggml_is_contiguous (src1)) {
166- const float * src0_d = (const float *) src0->data ;
167- const float * src1_d = (const float *) src1->data ;
168-
169- float * dst_d = (float *) dst->data ;
166+ const T * src0_d = (const T *) src0->data ;
167+ const T * src1_d = (const T *) src1->data ;
170168
169+ T * dst_d = (T *) dst->data ;
170+
171+ size_t type_size = elem_size (dst->type );
172+
171173 if (dim != 3 ) {
172174 for (int i3 = 0 ; i3 < dst->ne [3 ]; i3++) {
173- concat_f32_sycl (src0_d + i3 * (src0->nb [3 ] / 4 ), src1_d + i3 * (src1->nb [3 ] / 4 ),
174- dst_d + i3 * (dst->nb [3 ] / 4 ), src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], dst->ne [0 ],
175+ concat_T_sycl<T> (src0_d + i3 * (src0->nb [3 ] / type_size ), src1_d + i3 * (src1->nb [3 ] / type_size ),
176+ dst_d + i3 * (dst->nb [3 ] / type_size ), src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], dst->ne [0 ],
175177 dst->ne [1 ], dst->ne [2 ], dim, stream);
176178 }
177179 } else {
178180 const size_t size0 = ggml_nbytes (src0);
179181 const size_t size1 = ggml_nbytes (src1);
180182
181183 SYCL_CHECK (CHECK_TRY_ERROR (stream->memcpy (dst_d, src0_d, size0).wait ()));
182- SYCL_CHECK (CHECK_TRY_ERROR (stream->memcpy (dst_d + size0 / 4 , src1_d, size1).wait ()));
184+ SYCL_CHECK (CHECK_TRY_ERROR (stream->memcpy (dst_d + size0 / type_size , src1_d, size1).wait ()));
183185 }
184186 } else {
185- concat_f32_sycl_non_cont (stream, (const char *) src0->data , (const char *) src1->data , (char *) dst->data ,
187+ concat_T_sycl_non_cont<T> (stream, (const char *) src0->data , (const char *) src1->data , (char *) dst->data ,
186188 src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], src0->ne [3 ], src0->nb [0 ], src0->nb [1 ],
187189 src0->nb [2 ], src0->nb [3 ], src1->ne [0 ], src1->ne [1 ], src1->ne [2 ], src1->ne [3 ],
188190 src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ], dst->ne [0 ], dst->ne [1 ], dst->ne [2 ],
189191 dst->ne [3 ], dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ], dim);
190192 }
191193}
194+
195+ void ggml_sycl_op_concat (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
196+
197+ static std::atomic<bool > printed{false };
198+ if (!printed.exchange (true )) std::fprintf (stderr, " [LP] hit ggml_sycl_op_concat\n " );
199+
200+ LP_PROFILE_INIT_ONCE ();
201+ LP_PROFILE_PAIR (dst->src [0 ], dst->src [1 ]);
202+
203+ switch (dst->type ) {
204+ case GGML_TYPE_F32:
205+ concat_impl_sycl<float >(ctx, dst);
206+ break ;
207+ case GGML_TYPE_I16:
208+ concat_impl_sycl<int16_t >(ctx, dst);
209+ break ;
210+ case GGML_TYPE_I32:
211+ concat_impl_sycl<int32_t >(ctx, dst);
212+ break ;
213+ default :
214+ GGML_ASSERT (false && " ggml_sycl_op_concat: unsupported type" );
215+ break ;
216+ }
217+ }
0 commit comments