Skip to content

Commit 2a9b730

Browse files
committed
cuda : fix rope for non-contiguous input
ggml-ci
1 parent 75c91de commit 2a9b730

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

ggml/src/ggml-cuda/rope.cu

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -50,21 +50,21 @@ static __global__ void rope_norm(
5050

5151
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
5252

53+
const int row_x = row_dst % ne1;
54+
const int channel_x = row_dst / ne1;
55+
56+
const int idst = row_dst*ne0 + i0;
57+
const int ix = channel_x*s2 + row_x*s1 + i0;
58+
5359
if (i0 >= n_dims) {
5460
const int i = row_dst*ne0 + i0;
5561

56-
dst[i + 0] = x[i + 0];
57-
dst[i + 1] = x[i + 1];
62+
dst[i + 0] = x[ix + 0];
63+
dst[i + 1] = x[ix + 1];
5864

5965
return;
6066
}
6167

62-
const int row_x = row_dst % ne1;
63-
const int channel_x = row_dst / ne1;
64-
65-
const int idst = row_dst*ne0 + i0;
66-
const int ix = channel_x*s2 + row_x*s1 + i0;
67-
6868
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
6969

7070
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -94,21 +94,21 @@ static __global__ void rope_neox(
9494

9595
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
9696

97+
const int row_x = row_dst % ne1;
98+
const int channel_x = row_dst / ne1;
99+
100+
const int idst = row_dst*ne0 + i0/2;
101+
const int ix = channel_x*s2 + row_x*s1 + i0/2;
102+
97103
if (i0 >= n_dims) {
98104
const int i = row_dst*ne0 + i0;
99105

100-
dst[i + 0] = x[i + 0];
101-
dst[i + 1] = x[i + 1];
106+
dst[i + 0] = x[ix + i0/2 + 0];
107+
dst[i + 1] = x[ix + i0/2 + 1];
102108

103109
return;
104110
}
105111

106-
const int row_x = row_dst % ne1;
107-
const int channel_x = row_dst / ne1;
108-
109-
const int idst = row_dst*ne0 + i0/2;
110-
const int ix = channel_x*s2 + row_x*s1 + i0/2;
111-
112112
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
113113

114114
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

tests/test-backend-ops.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5323,12 +5323,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
53235323
for (bool fw : {true, false}) { // fw == forward
53245324
bool all = true;
53255325

5326-
for (float v : { 0, 1 }) {
5327-
for (float fs : { 1.0f, 1.4245f }) {
5328-
for (float ef : { 0.0f, 0.7465f }) {
5329-
for (float af : { 1.0f, 1.4245f }) {
5330-
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
5331-
for (bool ff : {false, true}) { // freq_factors
5326+
for (float fs : { 1.0f, 1.4245f }) {
5327+
for (float ef : { 0.0f, 0.7465f }) {
5328+
for (float af : { 1.0f, 1.4245f }) {
5329+
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
5330+
for (bool ff : {false, true}) { // freq_factors
5331+
for (float v : { 0, 1 }) {
53325332
test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
53335333

53345334
if (all) {
@@ -5341,8 +5341,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
53415341
test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
53425342
test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
53435343
test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
5344+
5345+
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
5346+
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
5347+
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
5348+
53445349
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
53455350
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
5351+
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
53465352
}
53475353

53485354
if (all) {

0 commit comments

Comments
 (0)