Skip to content

Commit 31af27a

Browse files
committed
cont : fix multi-rope + add test
ggml-ci
1 parent 2a9b730 commit 31af27a

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

ggml/src/ggml-cuda/rope.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -138,21 +138,21 @@ static __global__ void rope_multi(
138138

139139
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
140140

141+
const int row_x = row_dst % ne1;
142+
const int channel_x = row_dst / ne1;
143+
144+
const int idst = row_dst*ne0 + i0/2;
145+
const int ix = channel_x*s2 + row_x*s1 + i0/2;
146+
141147
if (i0 >= n_dims) {
142148
const int i = row_dst*ne0 + i0;
143149

144-
dst[i + 0] = x[i + 0];
145-
dst[i + 1] = x[i + 1];
150+
dst[i + 0] = x[ix + i0/2 + 0];
151+
dst[i + 1] = x[ix + i0/2 + 1];
146152

147153
return;
148154
}
149155

150-
const int row_x = row_dst % ne1;
151-
const int channel_x = row_dst / ne1;
152-
153-
const int idst = row_dst*ne0 + i0/2;
154-
const int ix = channel_x*s2 + row_x*s1 + i0/2;
155-
156156
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
157157
const int sec_w = sections.v[1] + sections.v[0];
158158
const int sector = (i0 / 2) % sect_dims;

tests/test-backend-ops.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5342,9 +5342,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
53425342
test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
53435343
test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
53445344

5345-
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
5346-
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
5347-
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
5345+
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw));
5346+
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
5347+
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
53485348

53495349
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
53505350
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
@@ -5354,6 +5354,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
53545354
if (all) {
53555355
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
53565356
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
5357+
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
5358+
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
53575359
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
53585360
}
53595361

0 commit comments

Comments
 (0)