Skip to content

Commit b456e69

Browse files
author
Gitty Burstein
committed
SYCL: update element-wise ops and presets
1 parent aa0c461 commit b456e69

File tree

4 files changed

+45
-0
lines changed

4 files changed

+45
-0
lines changed

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,14 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
417417
});
418418
}
419419

420+
template<typename T>
421+
static void arange_kernel(T * dst, const int k, T start, T step,
422+
const sycl::nd_item<1> &item_ct1) {
423+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
424+
dst[i] = start + static_cast<T>(i) * step;
425+
}
426+
}
427+
420428
template<typename T>
421429
static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
422430
const int nb02, const int nb03, const int ne10, const int ne11,
@@ -631,6 +639,30 @@ static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, gg
631639
}
632640
}
633641

642+
// ב-namespace ggml_sycl_detail:
643+
static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
644+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
645+
646+
float start, stop, step;
647+
memcpy(&start, dst->op_params, sizeof(float));
648+
memcpy(&stop, (float *) dst->op_params + 1, sizeof(float));
649+
memcpy(&step, (float *) dst->op_params + 2, sizeof(float));
650+
651+
dpct::queue_ptr stream = ctx.stream();
652+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
653+
654+
float * dst_ptr = (float *)dst->data;
655+
const int k = (int)ggml_nelements(dst); // הוספה חשובה!
656+
657+
const int num_blocks = ceil_div(k, SYCL_ARANGE_BLOCK_SIZE);
658+
stream->parallel_for(
659+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
660+
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
661+
[=](sycl::nd_item<1> item_ct1) {
662+
arange_kernel(dst_ptr, k, start, step, item_ct1);
663+
});
664+
}
665+
634666
} // namespace ggml_sycl_detail
635667

636668

@@ -1168,3 +1200,8 @@ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
11681200
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
11691201
ggml_sycl_op_geglu_quick(ctx, dst);
11701202
}
1203+
1204+
// Public entry point for the ARANGE op on the SYCL backend.
//
// Creates a scope_op_debug_print for this call (tagged with this function's
// name and num_src = 0, since the op is given no source tensors here), then
// forwards to ggml_sycl_detail::ggml_sycl_op_arange to do the work.
void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print debug_scope(__func__, dst, /*num_src=*/0);
    ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst);
}

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,6 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8383
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8484
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8585

86+
// Runs the ARANGE op on the SYCL backend, writing the generated sequence into dst.
void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
87+
8688
#endif // GGML_SYCL_ELEMENTWISE_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3768,6 +3768,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
37683768
case GGML_OP_GATED_LINEAR_ATTN:
37693769
ggml_sycl_op_gated_linear_attn(ctx, dst);
37703770
break;
3771+
case GGML_OP_ARANGE:
3772+
ggml_sycl_arange(ctx, dst);
3773+
break;
37713774
default:
37723775
return false;
37733776
}
@@ -4416,6 +4419,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
44164419
case GGML_OP_RWKV_WKV7:
44174420
case GGML_OP_GATED_LINEAR_ATTN:
44184421
return true;
4422+
case GGML_OP_ARANGE:
4423+
return op->type == GGML_TYPE_F32;
44194424
default:
44204425
return false;
44214426
}

ggml/src/ggml-sycl/presets.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#define SYCL_ARGMAX_BLOCK_SIZE 256
5050
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
5151
#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
52+
// Work-group (local range) size used when launching the arange kernel.
#define SYCL_ARANGE_BLOCK_SIZE 256
5253

5354
// dmmv = dequantize_mul_mat_vec
5455
#ifndef GGML_SYCL_DMMV_X

0 commit comments

Comments
 (0)