@@ -417,6 +417,14 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
417417 });
418418}
419419
420+ template <typename T>
421+ static void arange_kernel (T * dst, const int k, T start, T step,
422+ const sycl::nd_item<1 > &item_ct1) {
423+ SYCL_GLOBAL_ID_LOOP (k, item_ct1) {
424+ dst[i] = start + static_cast <T>(i) * step;
425+ }
426+ }
427+
420428template <typename T>
421429static void upscale_sycl (const T *x, T *dst, const int nb00, const int nb01,
422430 const int nb02, const int nb03, const int ne10, const int ne11,
@@ -631,6 +639,30 @@ static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, gg
631639 }
632640}
633641
642+ // ב-namespace ggml_sycl_detail:
643+ static inline void ggml_sycl_op_arange (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
644+ GGML_ASSERT (dst->type == GGML_TYPE_F32);
645+
646+ float start, stop, step;
647+ memcpy (&start, dst->op_params , sizeof (float ));
648+ memcpy (&stop, (float *) dst->op_params + 1 , sizeof (float ));
649+ memcpy (&step, (float *) dst->op_params + 2 , sizeof (float ));
650+
651+ dpct::queue_ptr stream = ctx.stream ();
652+ SYCL_CHECK (ggml_sycl_set_device (ctx.device ));
653+
654+ float * dst_ptr = (float *)dst->data ;
655+ const int k = (int )ggml_nelements (dst); // הוספה חשובה!
656+
657+ const int num_blocks = ceil_div (k, SYCL_ARANGE_BLOCK_SIZE);
658+ stream->parallel_for (
659+ sycl::nd_range<1 >(sycl::range<1 >(num_blocks) * sycl::range<1 >(SYCL_ARANGE_BLOCK_SIZE),
660+ sycl::range<1 >(SYCL_ARANGE_BLOCK_SIZE)),
661+ [=](sycl::nd_item<1 > item_ct1) {
662+ arange_kernel (dst_ptr, k, start, step, item_ct1);
663+ });
664+ }
665+
634666} // namespace ggml_sycl_detail
635667
636668
@@ -1168,3 +1200,8 @@ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
11681200 scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 1 );
11691201 ggml_sycl_op_geglu_quick (ctx, dst);
11701202}
1203+
1204+ void ggml_sycl_arange (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1205+ scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 0 );
1206+ ggml_sycl_detail::ggml_sycl_op_arange (ctx, dst);
1207+ }
0 commit comments