@@ -159,34 +159,34 @@ static void concat_f32_sycl_non_cont(
159159}
160160
161161static void ggml_sycl_op_concat (ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
162- const ggml_tensor *src0 = dst->src [0 ] ;
163- const ggml_tensor *src1 = dst->src [ 1 ] ;
164- queue_ptr stream = ctx. stream () ;
165- SYCL_CHECK ( ggml_sycl_set_device (ctx. device )) ;
166-
167- const int32_t dim = (( int32_t *)dst-> op_params )[ 0 ] ;
168-
169- if ( ggml_is_contiguous (src0) && ggml_is_contiguous (src1)) {
170- const float *src0_d = ( const float *)src0-> data ;
171- const float *src1_d = ( const float *) src1-> data ;
172-
173- float *dst_d = (float *)dst ->data ;
174-
175- if (dim != 3 ) {
176- for ( int i3 = 0 ; i3 < dst-> ne [ 3 ]; i3++) {
177- concat_f32_sycl (
178- src0_d + i3 * (src0-> nb [ 3 ] / 4 ), src1_d + i3 * (src1-> nb [3 ] / 4 ),
179- dst_d + i3 * (dst ->nb [3 ] / 4 ), src0-> ne [ 0 ], src0-> ne [ 1 ] ,
180- src0-> ne [ 2 ], dst ->ne [0 ], dst ->ne [1 ], dst ->ne [2 ], dim, stream);
181- }
182- } else {
183- const size_t size0 = ggml_nbytes (src0);
184- const size_t size1 = ggml_nbytes (src1 );
185-
186- SYCL_CHECK ( CHECK_TRY_ERROR (stream-> memcpy (dst_d, src0_d, size0). wait ()));
187- SYCL_CHECK (CHECK_TRY_ERROR (
188- stream->memcpy (dst_d + size0 / 4 , src1_d, size1).wait ()));
189- }
162+ GGML_ASSERT (! ggml_backend_buffer_is_sycl_split ( dst->src [1 ]-> buffer )) ;
163+ GGML_ASSERT (! ggml_backend_buffer_is_sycl_split ( dst->buffer )) ;
164+ const ggml_tensor * src0 = dst-> src [ 0 ] ;
165+ const ggml_tensor * src1 = dst-> src [ 1 ] ;
166+ queue_ptr stream = ctx. stream ();
167+ SYCL_CHECK ( ggml_sycl_set_device (ctx. device )) ;
168+
169+ const int32_t dim = (( int32_t *) dst-> op_params )[ 0 ];
170+
171+ if ( ggml_is_contiguous (src0) && ggml_is_contiguous ( src1)) {
172+ const float * src0_d = ( const float *) src0-> data ;
173+ const float * src1_d = (const float *) src1 ->data ;
174+
175+ float * dst_d = ( float *) dst-> data ;
176+
177+ if (dim != 3 ) {
178+ for ( int i3 = 0 ; i3 < dst-> ne [3 ]; i3++) {
179+ concat_f32_sycl (src0_d + i3 * (src0 ->nb [3 ] / 4 ), src1_d + i3 * (src1-> nb [ 3 ] / 4 ) ,
180+ dst_d + i3 * (dst-> nb [ 3 ] / 4 ), src0 ->ne [0 ], src0 ->ne [1 ], src0 ->ne [2 ], dst-> ne [ 0 ],
181+ dst-> ne [ 1 ], dst-> ne [ 2 ], dim, stream);
182+ }
183+ } else {
184+ const size_t size0 = ggml_nbytes (src0 );
185+ const size_t size1 = ggml_nbytes (src1);
186+
187+ SYCL_CHECK (CHECK_TRY_ERROR (stream-> memcpy (dst_d, src0_d, size0). wait ()));
188+ SYCL_CHECK ( CHECK_TRY_ERROR ( stream->memcpy (dst_d + size0 / 4 , src1_d, size1).wait ()));
189+ }
190190 } else
191191 concat_f32_sycl_non_cont (
192192 stream, (const char *)src0->data , (const char *)src1->data ,
0 commit comments