Skip to content

Commit bcbf7bc

Browse files
committed
cont : fix sycl + clean-up cuda
ggml-ci
1 parent 96998d7 commit bcbf7bc

File tree

2 files changed

+8
-17
lines changed

2 files changed

+8
-17
lines changed

ggml/src/ggml-cuda/rope.cu

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,8 @@ static __global__ void rope_norm(
5757
const int ix = channel_x*s2 + row_x*s1 + i0;
5858

5959
if (i0 >= n_dims) {
60-
const int i = row_dst*ne0 + i0;
61-
62-
dst[i + 0] = x[ix + 0];
63-
dst[i + 1] = x[ix + 1];
60+
dst[idst + 0] = x[ix + 0];
61+
dst[idst + 1] = x[ix + 1];
6462

6563
return;
6664
}
@@ -101,10 +99,8 @@ static __global__ void rope_neox(
10199
const int ix = channel_x*s2 + row_x*s1 + i0/2;
102100

103101
if (i0 >= n_dims) {
104-
const int i = row_dst*ne0 + i0;
105-
106-
dst[i + 0] = x[ix + i0/2 + 0];
107-
dst[i + 1] = x[ix + i0/2 + 1];
102+
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
103+
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
108104

109105
return;
110106
}
@@ -145,10 +141,8 @@ static __global__ void rope_multi(
145141
const int ix = channel_x*s2 + row_x*s1 + i0/2;
146142

147143
if (i0 >= n_dims) {
148-
const int i = row_dst*ne0 + i0;
149-
150-
dst[i + 0] = x[ix + i0/2 + 0];
151-
dst[i + 1] = x[ix + i0/2 + 1];
144+
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
145+
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
152146

153147
return;
154148
}

ggml/src/ggml-sycl/rope.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
5454
const int i2 = channel0 * s2 + row0 * s1 + i0;
5555

5656
if (i0 >= n_dims) {
57-
const int i = row * ne0 + i0;
5857
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
5958
return;
6059
}
@@ -95,8 +94,7 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
9594
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
9695

9796
if (i0 >= n_dims) {
98-
const int i = row * ne0 + i0;
99-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + i);
97+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
10098
return;
10199
}
102100

@@ -135,8 +133,7 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
135133
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
136134

137135
if (i0 >= n_dims) {
138-
const int i = row_dst*ne0 + i0;
139-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + i);
136+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
140137
return;
141138
}
142139

0 commit comments

Comments
 (0)