@@ -159,39 +159,37 @@ static void concat_f32_sycl_non_cont(
159159}
160160
161161void ggml_sycl_op_concat (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
162- const ggml_tensor *src0 = dst->src [0 ];
163- const ggml_tensor *src1 = dst->src [1 ];
164- queue_ptr stream = ctx.stream ();
165-
166- const int32_t dim = ((int32_t *)dst->op_params )[0 ];
167-
168- if (ggml_is_contiguous (src0) && ggml_is_contiguous (src1)) {
169- const float *src0_d = (const float *)src0->data ;
170- const float *src1_d = (const float *)src1->data ;
171-
172- float *dst_d = (float *)dst->data ;
173-
174- if (dim != 3 ) {
175- for (int i3 = 0 ; i3 < dst->ne [3 ]; i3++) {
176- concat_f32_sycl (
177- src0_d + i3 * (src0->nb [3 ] / 4 ), src1_d + i3 * (src1->nb [3 ] / 4 ),
178- dst_d + i3 * (dst->nb [3 ] / 4 ), src0->ne [0 ], src0->ne [1 ],
179- src0->ne [2 ], dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dim, stream);
180- }
162+ scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 2 );
163+ const ggml_tensor * src0 = dst->src [0 ];
164+ const ggml_tensor * src1 = dst->src [1 ];
165+ queue_ptr stream = ctx.stream ();
166+
167+ const int32_t dim = ((int32_t *) dst->op_params )[0 ];
168+
169+ if (ggml_is_contiguous (src0) && ggml_is_contiguous (src1)) {
170+ const float * src0_d = (const float *) src0->data ;
171+ const float * src1_d = (const float *) src1->data ;
172+
173+ float * dst_d = (float *) dst->data ;
174+
175+ if (dim != 3 ) {
176+ for (int i3 = 0 ; i3 < dst->ne [3 ]; i3++) {
177+ concat_f32_sycl (src0_d + i3 * (src0->nb [3 ] / 4 ), src1_d + i3 * (src1->nb [3 ] / 4 ),
178+ dst_d + i3 * (dst->nb [3 ] / 4 ), src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], dst->ne [0 ],
179+ dst->ne [1 ], dst->ne [2 ], dim, stream);
180+ }
181+ } else {
182+ const size_t size0 = ggml_nbytes (src0);
183+ const size_t size1 = ggml_nbytes (src1);
184+
185+ SYCL_CHECK (CHECK_TRY_ERROR (stream->memcpy (dst_d, src0_d, size0).wait ()));
186+ SYCL_CHECK (CHECK_TRY_ERROR (stream->memcpy (dst_d + size0 / 4 , src1_d, size1).wait ()));
187+ }
181188 } else {
182- const size_t size0 = ggml_nbytes (src0);
183- const size_t size1 = ggml_nbytes (src1);
184-
185- SYCL_CHECK (CHECK_TRY_ERROR (stream->memcpy (dst_d, src0_d, size0).wait ()));
186- SYCL_CHECK (CHECK_TRY_ERROR (
187- stream->memcpy (dst_d + size0 / 4 , src1_d, size1).wait ()));
189+ concat_f32_sycl_non_cont (stream, (const char *) src0->data , (const char *) src1->data , (char *) dst->data ,
190+ src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], src0->ne [3 ], src0->nb [0 ], src0->nb [1 ],
191+ src0->nb [2 ], src0->nb [3 ], src1->ne [0 ], src1->ne [1 ], src1->ne [2 ], src1->ne [3 ],
192+ src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ], dst->ne [0 ], dst->ne [1 ], dst->ne [2 ],
193+ dst->ne [3 ], dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ], dim);
188194 }
189- } else
190- concat_f32_sycl_non_cont (
191- stream, (const char *)src0->data , (const char *)src1->data ,
192- (char *)dst->data , src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], src0->ne [3 ],
193- src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ], src1->ne [0 ],
194- src1->ne [1 ], src1->ne [2 ], src1->ne [3 ], src1->nb [0 ], src1->nb [1 ],
195- src1->nb [2 ], src1->nb [3 ], dst->ne [0 ], dst->ne [1 ], dst->ne [2 ],
196- dst->ne [3 ], dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ], dim);
197195}
0 commit comments