Skip to content

Commit b22572e

Browse files
GittyBursteinGitty Burstein
andauthored
sycl : add ARANGE operator (ggml-org#16362)
* SYCL: update element-wise ops and presets * clean arange * Re-trigger CI --------- Co-authored-by: Gitty Burstein <[email protected]>
1 parent 7a50cf3 commit b22572e

File tree

4 files changed

+40
-0
lines changed

4 files changed

+40
-0
lines changed

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,14 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
397397
});
398398
}
399399

400+
template<typename T>
401+
static void arange_kernel(T * dst, const int k, T start, T step,
402+
const sycl::nd_item<1> &item_ct1) {
403+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
404+
dst[i] = start + static_cast<T>(i) * step;
405+
}
406+
}
407+
400408
template<typename T>
401409
static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
402410
const int nb02, const int nb03, const int ne10, const int ne11,
@@ -565,6 +573,25 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
565573
}
566574

567575

576+
static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
577+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
578+
float start, stop, step;
579+
memcpy(&start, dst->op_params, sizeof(float));
580+
memcpy(&stop, (float *) dst->op_params + 1, sizeof(float));
581+
memcpy(&step, (float *) dst->op_params + 2, sizeof(float));
582+
dpct::queue_ptr stream = ctx.stream();
583+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
584+
float * dst_ptr = (float *)dst->data;
585+
const int k = (int)ggml_nelements(dst);
586+
const int num_blocks = ceil_div(k, SYCL_ARANGE_BLOCK_SIZE);
587+
stream->parallel_for(
588+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
589+
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
590+
[=](sycl::nd_item<1> item_ct1) {
591+
arange_kernel(dst_ptr, k, start, step, item_ct1);
592+
});
593+
}
594+
568595
} // namespace ggml_sycl_detail
569596

570597

@@ -1090,3 +1117,8 @@ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
10901117
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
10911118
ggml_sycl_op_geglu_quick(ctx, dst);
10921119
}
1120+
1121+
void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1122+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);
1123+
ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst);
1124+
}

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,6 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8181
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8282
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8383

84+
void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
85+
8486
#endif // GGML_SYCL_ELEMENTWISE_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3832,6 +3832,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
38323832
case GGML_OP_GATED_LINEAR_ATTN:
38333833
ggml_sycl_op_gated_linear_attn(ctx, dst);
38343834
break;
3835+
case GGML_OP_ARANGE:
3836+
ggml_sycl_arange(ctx, dst);
3837+
break;
38353838
default:
38363839
return false;
38373840
}
@@ -4478,6 +4481,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
44784481
case GGML_OP_RWKV_WKV7:
44794482
case GGML_OP_GATED_LINEAR_ATTN:
44804483
return true;
4484+
case GGML_OP_ARANGE:
4485+
return op->type == GGML_TYPE_F32;
44814486
default:
44824487
return false;
44834488
}

ggml/src/ggml-sycl/presets.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#define SYCL_ARGMAX_BLOCK_SIZE 256
5050
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
5151
#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
52+
#define SYCL_ARANGE_BLOCK_SIZE 256
5253

5354
// dmmv = dequantize_mul_mat_vec
5455
#ifndef GGML_SYCL_DMMV_X

0 commit comments

Comments
 (0)