
Commit a3fb31b

Merge branch 'ggml-org:master' into master
2 parents 71595c6 + 4146d6a commit a3fb31b

13 changed files with 481 additions and 34 deletions


.github/workflows/docker.yml

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ jobs:
           # https://github.com/ggml-org/llama.cpp/issues/11888
           #- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: false }
           - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
-          - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
+          - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
           - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
           - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
           - { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 26 additions & 0 deletions
@@ -27,6 +27,7 @@
 #include "ggml-cuda/mmq.cuh"
 #include "ggml-cuda/mmvf.cuh"
 #include "ggml-cuda/mmvq.cuh"
+#include "ggml-cuda/moe-expert-reduce.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
 #include "ggml-cuda/opt-step-sgd.cuh"
@@ -3169,6 +3170,31 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             continue;
         }
 
+        if (node->op == GGML_OP_MUL) {
+            int current_node = i + 1;
+            int num_views = 0;
+            int num_adds = 0;
+            while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
+                num_views++;
+                current_node++;
+            }
+
+            while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_ADD &&
+                   num_adds < num_views - 1) {
+                num_adds++;
+                current_node++;
+            }
+
+            if (num_adds == num_views - 1 && num_views > 0) {
+                ggml_tensor * dst_node = cgraph->nodes[current_node - 1];
+                if (ggml_cuda_should_use_moe_expert_reduce(cgraph, i, current_node)) {
+                    ggml_cuda_op_moe_expert_reduce(*cuda_ctx, node->src[0], node->src[1], dst_node);
+                    i += num_views + num_adds;
+                    continue;
+                }
+            }
+        }
+
         if (node->op == GGML_OP_ADD) {
             int n_fuse = 0;
             ggml_op ops[8];
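
For orientation (not part of the commit): the node chain detected here is the standard MoE weighted-sum epilogue. The MUL node scales each selected expert's activations by its routing weight, the VIEW nodes slice one expert each out of that product, and the ADD chain sums the slices. A minimal scalar sketch of the same result, assuming the [n_rows][n_expert_used][n_cols] layout the new kernel also assumes:

static void moe_expert_reduce_ref(const float * experts, const float * weights, float * dst,
                                  int n_expert_used, int n_cols, int n_rows) {
    for (int row = 0; row < n_rows; ++row) {          // one token per row
        for (int col = 0; col < n_cols; ++col) {      // one output feature per column
            float acc = 0.0f;
            for (int e = 0; e < n_expert_used; ++e) { // weighted sum over the selected experts
                acc += weights[row * n_expert_used + e] *
                       experts[(row * n_expert_used + e) * n_cols + col];
            }
            dst[row * n_cols + col] = acc;
        }
    }
}

The fused path produces this result in a single kernel launch instead of one MUL, num_views VIEWs, and num_views - 1 ADDs; the `i += num_views + num_adds` line then skips the graph nodes that the fused call has already covered.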
ggml/src/ggml-cuda/moe-expert-reduce.cu (new file)

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
#include "moe-expert-reduce.cuh"

// This kernel is a fusion of the expert weight reduce, common in MoE models

template <int n_expert_used_template>
__global__ void moe_expert_reduce_cuda(const float * __restrict__ experts,
                                       const float * __restrict__ weights,
                                       float * __restrict__ dst,
                                       const int n_expert_used,
                                       const int n_cols) {
    const int row = blockIdx.x;
    const int col = blockIdx.y * blockDim.x + threadIdx.x;
    if (col >= n_cols) {
        return;
    }

    experts += row * n_cols * n_expert_used;
    weights += row * n_expert_used;
    dst += row * n_cols;

    float acc = 0.f;
    if constexpr (n_expert_used_template == 0) {
        for (int expert = 0; expert < n_expert_used; ++expert) {
            ggml_cuda_mad(acc, experts[col], weights[expert]);
            experts += n_cols;
        }
        dst[col] = acc;
    } else {
#pragma unroll
        for (int i = 0; i < n_expert_used_template; ++i) {
            ggml_cuda_mad(acc, experts[col], weights[i]);
            experts += n_cols;
        }
        dst[col] = acc;
    }
}

static void launch_moe_expert_reduce(ggml_backend_cuda_context & ctx,
                                     const float * experts,
                                     const float * weights,
                                     float * dst,
                                     const int n_expert_used,
                                     const int n_cols,
                                     const int n_rows) {
    const int block_size = 32;

    const int n_blocks_x = n_rows;
    const int n_blocks_y = (n_cols + block_size - 1) / block_size;

    dim3 block_dims(block_size);
    dim3 grid_dims(n_blocks_x, n_blocks_y);

    cudaStream_t stream = ctx.stream();
    switch (n_expert_used) {
        case 1:
            moe_expert_reduce_cuda<1>
                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
            break;
        case 2:
            moe_expert_reduce_cuda<2>
                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
            break;
        case 4:
            moe_expert_reduce_cuda<4>
                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
            break;
        case 6:
            moe_expert_reduce_cuda<6>
                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
            break;
        case 8:
            moe_expert_reduce_cuda<8>
                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
            break;
        case 16:
            moe_expert_reduce_cuda<16>
                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
            break;
        case 32:
            moe_expert_reduce_cuda<32>
                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
            break;
        case 64:
            moe_expert_reduce_cuda<64>
                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
            break;
        case 128:
            moe_expert_reduce_cuda<128>
                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
            break;
        default:
            moe_expert_reduce_cuda<0>
                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
            break;
    }
}

bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index) {
    const ggml_tensor * mul = cgraph->nodes[start_index];

    if (mul->op != GGML_OP_MUL || !ggml_is_contiguous(mul->src[0]) || !ggml_is_contiguous(mul->src[1])) {
        return false;
    }

    int current_node = start_index + 1;
    size_t current_offset = 0;

    std::vector<const ggml_tensor *> view_nodes;
    //check if all are views of the expert in increasing order
    while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
        const ggml_tensor * node = cgraph->nodes[current_node];
        if (node->view_src != mul) {
            return false;
        }
        if (node->view_offs < current_offset) {
            return false;
        }
        current_offset = node->view_offs;
        current_node++;
        view_nodes.push_back(node);
    }

    //check if all the adds are in increasing order
    const ggml_tensor * prev_add_src = view_nodes.empty() ? nullptr : view_nodes[0];
    int num_adds = 0;
    int num_views = view_nodes.size();
    while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_ADD) {
        const ggml_tensor * add_node = cgraph->nodes[current_node];

        bool is_first_op_ok = num_views > num_adds ? add_node->src[0] == prev_add_src : false;
        bool is_second_op_ok = num_views > num_adds ? add_node->src[1] == view_nodes[num_adds + 1] : false;

        if (!is_first_op_ok || !is_second_op_ok) {
            return false;
        }
        prev_add_src = add_node;

        num_adds++;
        current_node++;
    }

    if (num_views != num_adds + 1) {
        return false;
    }

    return true;
}

void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
                                    const ggml_tensor * experts,
                                    const ggml_tensor * weights,
                                    ggml_tensor * dst) {
    const int n_rows = experts->ne[2];
    const int n_expert_used = experts->ne[1];
    const int n_cols = experts->ne[0];

    GGML_ASSERT(experts->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(experts));
    GGML_ASSERT(ggml_is_contiguous(weights));
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const float * experts_d = (const float *) experts->data;
    const float * weights_d = (const float *) weights->data;
    float * dst_d = (float *) dst->data;

    launch_moe_expert_reduce(ctx, experts_d, weights_d, dst_d, n_expert_used, n_cols, n_rows);
}
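
To make the launch geometry concrete (illustrative numbers, not taken from the commit): each block has 32 threads, blockIdx.x selects the token row, and blockIdx.y tiles the columns. With, say, n_rows = 16, n_cols = 4096, and n_expert_used = 8, launch_moe_expert_reduce issues a grid of (16, (4096 + 31) / 32) = (16, 128) blocks of 32 threads and picks the moe_expert_reduce_cuda<8> specialization, whose per-expert loop is fully unrolled. Any n_expert_used without a dedicated case (for example 5) falls back to the <0> instantiation, which loops over the runtime n_expert_used argument instead.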
ggml/src/ggml-cuda/moe-expert-reduce.cuh (new file)

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
#include "common.cuh"
#include "ggml.h"

#include <initializer_list>

void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
                                    const ggml_tensor * experts,
                                    const ggml_tensor * weights,
                                    ggml_tensor * dst);

bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index);
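
These two declarations are the only entry points the rest of the backend needs: evaluate_and_capture_cuda_graph in ggml-cuda.cu (see the hunk above) first calls ggml_cuda_should_use_moe_expert_reduce to verify that the MUL/VIEW/ADD chain between start_index and end_index really is the expected expert reduction, and only then replaces it with a single ggml_cuda_op_moe_expert_reduce call, skipping over the fused nodes.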
