Skip to content

Commit 96998d7

Browse files
committed
sycl : try fix
ggml-ci
1 parent 31af27a commit 96998d7

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

ggml/src/ggml-sycl/rope.cpp

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -47,18 +47,18 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
4747

4848
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
4949

50-
if (i0 >= n_dims) {
51-
const int i = row * ne0 + i0;
52-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
53-
return;
54-
}
55-
5650
const int row0 = row % ne1;
5751
const int channel0 = row / ne1;
5852

5953
const int i = row * ne0 + i0;
6054
const int i2 = channel0 * s2 + row0 * s1 + i0;
6155

56+
if (i0 >= n_dims) {
57+
const int i = row * ne0 + i0;
58+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
59+
return;
60+
}
61+
6262
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
6363

6464
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -88,18 +88,18 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
8888

8989
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
9090

91-
if (i0 >= n_dims) {
92-
const int i = row * ne0 + i0;
93-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
94-
return;
95-
}
96-
9791
const int row0 = row % ne1;
9892
const int channel0 = row / ne1;
9993

10094
const int i = row * ne0 + i0 / 2;
10195
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
10296

97+
if (i0 >= n_dims) {
98+
const int i = row * ne0 + i0;
99+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + i);
100+
return;
101+
}
102+
103103
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
104104

105105
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -129,17 +129,17 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
129129
}
130130
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
131131

132-
if (i0 >= n_dims) {
133-
const int i = row_dst*ne0 + i0;
134-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
135-
return;
136-
}
137-
138132
const int row_x = row_dst % ne1;
139133
const int channel_x = row_dst / ne1;
140134
const int idst = (row_dst * ne0) + (i0 / 2);
141135
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
142136

137+
if (i0 >= n_dims) {
138+
const int i = row_dst*ne0 + i0;
139+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + i);
140+
return;
141+
}
142+
143143
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
144144
const int sec_w = sections.v[1] + sections.v[0];
145145
const int sector = (i0 / 2) % sect_dims;

0 commit comments

Comments (0)