Skip to content

Commit a2d24c9

Browse files
ikawrakow (Iwan Kawrakow)
and authored
TG improvements for MoE models (#404)
* cuda: Remove unnecessary device to host copy of row ids We get 3-4% TG speed improvement for DeepSeek-Lite just from that. * CPU: fix get_rows when SER is used With smart experts reduction (SER), one potentially uses fewer experts than specified by the model. This is accomplished by setting the ID of the not seected tensors to -1. Most of the necessary stuff was implemented when I added the SER option, but I forgot to update get_rows() for not quantized tensors. As a result, we get random garbage for the weights of the not-selected epxerts, which leads to garbage output. This commit fixes it on the CPU. I'm not quite sure yet why the GPU is not working. * CUDA: fix TG with SER --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 43a154d commit a2d24c9

File tree

3 files changed

+39
-20
lines changed

3 files changed

+39
-20
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2505,11 +2505,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
25052505
dst_padded_col_size, next->src[0]->type, stream);
25062506
CUDA_CHECK(cudaGetLastError());
25072507

2508-
std::vector<char> ids_host(ggml_nbytes(ids));
2509-
const char * ids_dev = (const char *) ids->data;
2510-
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
2511-
CUDA_CHECK(cudaStreamSynchronize(stream));
2512-
25132508
local_dst.ne[2] = 1;
25142509

25152510
auto local_next = *next;

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,27 @@ static __global__ void mul_mat_vec_q(
147147
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst,
148148
const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) {
149149
int i2 = blockIdx.y;
150+
char * cdst = (char *)dst + i2*nb2;
150151
int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2;
152+
if (i02 < 0) {
153+
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
154+
constexpr int rows_per_cuda_block = 1;
155+
#else
156+
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
157+
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
158+
const int row0 = rows_per_cuda_block*blockIdx.x;
159+
if (threadIdx.y == 0) {
160+
dst = (float *)cdst;
161+
for (int j = 0; j < ncols_y; ++j) {
162+
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
163+
dst[j*nrows_dst + row0 + threadIdx.x] = 0;
164+
}
165+
}
166+
}
167+
return;
168+
}
151169
const char * cx = (const char *)vx + i02*nb02;
152170
const char * cy = (const char *)vy + i2*nb12;
153-
char * cdst = (char *)dst + i2*nb2;
154171
mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
155172
}
156173

ggml/src/ggml.c

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15911,11 +15911,14 @@ static void ggml_compute_forward_get_rows_f16(
1591115911
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
1591215912
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
1591315913

15914-
assert(i01 >= 0 && i01 < ne01);
15914+
if (i01 >= 0 && i01 < ne01) {
15915+
ggml_fp16_to_fp32_row(
15916+
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
15917+
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
15918+
} else {
15919+
memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float));
15920+
}
1591515921

15916-
ggml_fp16_to_fp32_row(
15917-
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
15918-
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
1591915922
}
1592015923
}
1592115924

@@ -15952,11 +15955,13 @@ static void ggml_compute_forward_get_rows_bf16(
1595215955
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
1595315956
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
1595415957

15955-
assert(i01 >= 0 && i01 < ne01);
15956-
15957-
ggml_bf16_to_fp32_row(
15958-
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
15959-
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
15958+
if (i01 >= 0 && i01 < ne01) {
15959+
ggml_bf16_to_fp32_row(
15960+
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
15961+
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
15962+
} else {
15963+
memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float));
15964+
}
1596015965
}
1596115966
}
1596215967

@@ -15993,11 +15998,13 @@ static void ggml_compute_forward_get_rows_f32(
1599315998
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
1599415999
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
1599516000

15996-
assert(i01 >= 0 && i01 < ne01);
15997-
15998-
ggml_vec_cpy_f32(nc,
15999-
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
16000-
(float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
16001+
if (i01 >= 0 && i01 < ne01) {
16002+
ggml_vec_cpy_f32(nc,
16003+
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
16004+
(float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
16005+
} else {
16006+
memset((char *)dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float));
16007+
}
1600116008
}
1600216009
}
1600316010

0 commit comments

Comments (0)