Skip to content

Commit 10ce7fb

Browse files
committed
add imrope support for sycl
1 parent 6a0191a commit 10ce7fb

File tree

1 file changed

+30
-17
lines changed

1 file changed

+30
-17
lines changed

ggml/src/ggml-sycl/rope.cpp

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
119119
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
120120
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
121121
const float theta_scale, const float * freq_factors, const mrope_sections sections,
122-
const sycl::nd_item<3> & item_ct1) {
122+
const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
123123
// get index pos
124124
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
125125
if (i0 >= ne0) {
@@ -143,17 +143,29 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
143143

144144

145145
float theta_base = 0.0;
146-
if (sector < sections.v[0]) {
147-
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
148-
}
149-
else if (sector >= sections.v[0] && sector < sec_w) {
150-
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
151-
}
152-
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
153-
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
154-
}
155-
else if (sector >= sec_w + sections.v[2]) {
156-
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
146+
if (is_imrope) {
147+
if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
148+
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
149+
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
150+
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
151+
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
152+
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
153+
} else {
154+
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
155+
}
156+
} else {
157+
if (sector < sections.v[0]) {
158+
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
159+
}
160+
else if (sector >= sections.v[0] && sector < sec_w) {
161+
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
162+
}
163+
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
164+
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
165+
}
166+
else if (sector >= sec_w + sections.v[2]) {
167+
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
168+
}
157169
}
158170

159171
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -281,7 +293,7 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
281293
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
282294
const float freq_scale, const float freq_base, const float ext_factor,
283295
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
284-
const mrope_sections sections, queue_ptr stream) {
296+
const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
285297
GGML_ASSERT(ne0 % 2 == 0);
286298
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
287299
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
@@ -297,12 +309,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
297309
if (freq_factors == nullptr) {
298310
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
299311
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
300-
corr_dims, theta_scale, freq_factors, sections, item_ct1);
312+
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
301313
});
302314
} else {
303315
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
304316
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
305-
corr_dims, theta_scale, freq_factors, sections, item_ct1);
317+
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
306318
});
307319
}
308320
}
@@ -381,6 +393,7 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
381393

382394
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
383395
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
396+
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
384397
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
385398

386399
if (is_mrope) {
@@ -422,11 +435,11 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
422435
if (dst->src[0]->type == GGML_TYPE_F16) {
423436
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
424437
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
425-
freq_factors, sections, main_stream);
438+
freq_factors, sections, is_imrope, main_stream);
426439
} else if (dst->src[0]->type == GGML_TYPE_F32) {
427440
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
428441
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
429-
main_stream);
442+
is_imrope, main_stream);
430443
} else {
431444
GGML_ABORT("Fatal error: Tensor type unsupported!");
432445
}

0 commit comments

Comments
 (0)