diff --git a/src/ggml-cuda/ggml-cuda.cu b/src/ggml-cuda/ggml-cuda.cu
index 060550c572..2758fb5c9a 100644
--- a/src/ggml-cuda/ggml-cuda.cu
+++ b/src/ggml-cuda/ggml-cuda.cu
@@ -1942,35 +1942,21 @@ struct mmid_row_mapping {
     int32_t i2;
 };
 
-static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
-        int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
-        const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
-        int64_t ne11, int64_t ne10,
-        size_t nb11, size_t nb12) {
-    int32_t iid1 = blockIdx.x;
-    int32_t id = blockIdx.y;
-
-    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+// MIT licensed. Copyright (C) 2025 Iwan Kawrakow
+// https://github.com/ikawrakow/ik_llama.cpp/pull/283
+static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_original, char * __restrict__ src_contiguous,
+        const mmid_row_mapping * __restrict__ row_mapping,
+        int64_t ne10, int64_t ne11, size_t nb11, size_t nb12) {
+    const int32_t i = blockIdx.x;
 
-    if (row_id_i != i02) {
-        return;
-    }
+    const int32_t i11 = row_mapping[i].i1 % ne11;
+    const int32_t i12 = row_mapping[i].i2;
 
-    const int64_t i11 = id % ne11;
-    const int64_t i12 = iid1;
+    float * src_row_contiguous = (float *)(src_contiguous + i*nb11);
+    const float * src_row_original = (const float *)(src_original + i11*nb11 + i12*nb12);
 
-    __shared__ int src1_row;
-    if (threadIdx.x == 0) {
-        src1_row = atomicAdd(cur_src1_row, 1);
-        row_mapping[src1_row] = {id, iid1};
-    }
-    __syncthreads();
-
-    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
-    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
-
-    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
-        src1_row_contiguous[i] = src1_row_original[i];
+    for (int j = threadIdx.x; j < ne10; j += blockDim.x) {
+        src_row_contiguous[j] = src_row_original[j];
     }
 }
 
@@ -1991,6 +1977,53 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin
     }
 }
 
+// MIT licensed. Copyright (C) 2025 Iwan Kawrakow
+// https://github.com/ikawrakow/ik_llama.cpp/pull/283
+static inline void prepare_row_mappings(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
+        const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
+        ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
+
+    GGML_ASSERT(moe_counts.empty() && cum_moe_counts.empty());
+
+    auto stream = ctx.stream();
+
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    const char * ids_dev = (const char *) ids->data;
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    std::vector<mmid_row_mapping> rmapping(ids->ne[1]*n_ids);
+    moe_counts.resize(n_as, 0);
+    cum_moe_counts.resize(n_as + 1);
+
+    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+        for (int64_t id = 0; id < n_ids; id++) {
+            const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+            if (row_id_i >= 0 && row_id_i < n_as) ++moe_counts[row_id_i];
+        }
+    }
+    cum_moe_counts[0] = 0;
+    for (int64_t i = 0; i < n_as; ++i) {
+        cum_moe_counts[i+1] = cum_moe_counts[i] + moe_counts[i];
+    }
+
+    dev_row_mapping.alloc(cum_moe_counts[n_as]);
+
+    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+        for (int64_t id = 0; id < n_ids; id++) {
+            const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+            if (row_id_i >= 0 && row_id_i < n_as) {
+                rmapping[cum_moe_counts[row_id_i]++] = {(int)id, (int)iid1};
+            }
+        }
+    }
+
+    for (int64_t i = 0; i < n_as; ++i) cum_moe_counts[i] -= moe_counts[i];
+
+    CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+}
+
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
@@ -2005,11 +2038,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     const int64_t n_as = ne02;
     const int64_t n_ids = ids->ne[0];
 
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    const char * ids_dev = (const char *) ids->data;
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
-    CUDA_CHECK(cudaStreamSynchronize(stream));
-
     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;
@@ -2035,6 +2063,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     dst_row.nb[3] = nb1;
 
     if (ne12 == 1) {
+        std::vector<char> ids_host(ggml_nbytes(ids));
+        const char * ids_dev = (const char *) ids->data;
+        CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+        CUDA_CHECK(cudaStreamSynchronize(stream));
         for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
             for (int64_t id = 0; id < n_ids; id++) {
                 const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
@@ -2055,6 +2087,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             }
         }
     } else {
+        ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
+        std::vector<int> moe_counts, cum_moe_counts;
+        prepare_row_mappings(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
+
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
         ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
 
@@ -2062,39 +2098,19 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         dst_row.data = dst_contiguous.get();
 
         for (int64_t i02 = 0; i02 < n_as; i02++) {
-            int64_t num_src1_rows = 0;
-
-            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-                for (int64_t id = 0; id < n_ids; id++) {
-                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
-                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
-
-                    if (row_id_i != i02) {
-                        continue;
-                    }
-
-                    num_src1_rows++;
-                }
-            }
+            const int64_t num_src1_rows = moe_counts[i02];
 
             if (num_src1_rows == 0) {
                 continue;
            }
 
-            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+            const size_t mapping_offset = cum_moe_counts[i02];
 
             {
                 dim3 block_dims(std::min((unsigned int)ne10, 768u));
-                dim3 grid_dims(ids->ne[1], n_ids);
-                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                        src1_original, src1_contiguous.get(),
-                        dev_cur_src1_row.get(), dev_row_mapping.get(),
-                        ids_dev, i02, ids->nb[1], ids->nb[0],
-                        ne11, ne10,
-                        nb11, nb12);
+                dim3 grid_dims(num_src1_rows);
+                k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                    src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12);
                 CUDA_CHECK(cudaGetLastError());
             }
 
@@ -2120,7 +2136,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
                 dim3 grid_dims(num_src1_rows);
                 k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
                         dst_original, dst_contiguous.get(),
-                        dev_row_mapping.get(),
+                        dev_row_mapping.get() + mapping_offset,
                         ne0, nb1, nb2);
                 CUDA_CHECK(cudaGetLastError());
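Note (not part of the patch): the new prepare_row_mappings helper replaces the old per-expert atomic compaction with a host-side counting sort over the routing ids: count rows per expert, take a prefix sum to get each expert's start offset, then scatter (slot, token) pairs into one contiguous mapping array that the copy kernels index directly. A minimal standalone sketch of that bucketing, plain C++ with hypothetical names:

#include <cstdint>
#include <cstdio>
#include <vector>

// same layout as mmid_row_mapping in the patch: which expert slot of which token
struct row_mapping { int32_t i1, i2; };

int main() {
    const int n_as = 4;                                                     // number of experts
    const std::vector<std::vector<int32_t>> ids = {{0, 2}, {2, 3}, {0, 2}}; // top-k expert ids per token

    // pass 1: count how many rows land on each expert
    std::vector<int> counts(n_as, 0), offsets(n_as + 1, 0);
    for (const auto & row : ids)
        for (int32_t e : row)
            if (e >= 0 && e < n_as) ++counts[e];

    // exclusive prefix sum -> start of each expert's slice in the contiguous mapping
    for (int e = 0; e < n_as; ++e) offsets[e + 1] = offsets[e] + counts[e];

    // pass 2: scatter (slot, token) pairs into the per-expert slices
    std::vector<row_mapping> mapping(offsets[n_as]);
    std::vector<int> cursor(offsets.begin(), offsets.end() - 1);
    for (int32_t tok = 0; tok < (int32_t) ids.size(); ++tok)
        for (int32_t slot = 0; slot < (int32_t) ids[tok].size(); ++slot) {
            const int32_t e = ids[tok][slot];
            if (e >= 0 && e < n_as) mapping[cursor[e]++] = {slot, tok};
        }

    // expert e owns mapping[offsets[e] .. offsets[e+1])
    for (int e = 0; e < n_as; ++e)
        printf("expert %d: rows [%d, %d)\n", e, offsets[e], offsets[e + 1]);
    return 0;
}

With this layout each expert's rows occupy a contiguous slice of the mapping, which is what lets k_copy_src_to_contiguous launch a 1-D grid of exactly moe_counts[i02] blocks instead of scanning all ids with shared-memory atomics.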