Skip to content

Commit 5f087c4

Browse files
guschmue and Copilot
authored
optimized qmoe code path for 1 token (#27383)
avoids gpu -> cpu copy in qmoe and removes 1 of 6 shaders in qmoe. This improves token generation on gpt-oss-20b by ~15% --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 3db53eb commit 5f087c4

16 files changed

+417
-78
lines changed

onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
// in: router_values [num_tokens, num_experts]
66
// in: expert_tokens [used_by], mapping token idx to original token index
77
// out: output
8-
// uniform: used_by, hidden_size, num_experts, expert_idx
8+
// uniform: hidden_size, num_experts, expert_idx, token_offset
99

1010
$MAIN {
1111
let token_idx = expert_tokens[workgroup_idx];
1212
let step = uniforms.hidden_size / workgroup_size_x;
1313
let wg_offset = local_idx * step;
14-
// token_idx is the offset into hidden state while fc2_outputs is for the chunk and
14+
// token_idx is the offset into hidden state while fc2_outputs is for the chunk so
1515
// we need to subtract uniforms.token_offset
1616
let router_value_offset = (token_idx - uniforms.token_offset) * uniforms.num_experts + uniforms.expert_idx;
1717
let router_value = router_values[router_value_offset];
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
// in: fc2_outputs [used_by, inter_size]
5+
// in: router_values [num_tokens, num_experts]
6+
// in: indirect_experts
7+
// out: output
8+
// uniform: hidden_size, expert_idx
9+
10+
$MAIN {
11+
let expert_idx = indirect_experts[uniforms.expert_idx];
12+
let steps = uniforms.hidden_size / workgroup_size_x;
13+
let router_value = router_values[expert_idx];
14+
let offset = local_idx * steps;
15+
for (var i = 0u; i < steps; i++) {
16+
let weight = fc2_outputs[offset + i];
17+
output[offset + i] += router_value * weight;
18+
}
19+
}

onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
// MOE gate shader
66
//
77
// called with expert as local_idx and token_idx as workgroup_idx
8-
// in: router_values [num_tokens, num_experts], per expert float we multiply final results with
8+
// in: router_logits [num_tokens, num_experts], per expert float we multiply final results with
9+
// out: topk_values [num_tokens, num_experts], softmax weight for each token's top-k experts (0 for non-selected experts)
910
// out: gate_counts [num_experts], number of tokens assigned to each expert
1011
// out: gate_hidden [num_experts, num_tokens], token_idx assigned to each expert
1112
// uniform: rows(num_tokens), cols(num_experts), token_offset
@@ -21,7 +22,7 @@ const MAX_FLOAT: f16 = 65504.0;
2122
const MAX_FLOAT: f32 = 3.4028234663852886e+38;
2223
#endif
2324

24-
var<workgroup> shared_vals: array<hidden_state_element_t, workgroup_size_x>;
25+
var<workgroup> shared_vals: array<router_logits_element_t, workgroup_size_x>;
2526
var<workgroup> shared_idxs: array<u32, workgroup_size_x>;
2627

2728
$MAIN {
@@ -32,14 +33,14 @@ $MAIN {
3233
let cols = uniforms.cols;
3334
let output_base = row * cols;
3435

35-
var max_val: hidden_state_element_t = -MAX_FLOAT;
36+
var max_val: router_logits_element_t = -MAX_FLOAT;
3637
var max_idx: u32 = 0u;
3738

3839
if (global_idx < cols) {
3940
atomicStore(&tokencount_for_expert[global_idx], 0u);
4041
}
4142
if (local_idx < cols) {
42-
max_val = hidden_state[(row + uniforms.token_offset) * cols + local_idx];
43+
max_val = router_logits[(row + uniforms.token_offset) * cols + local_idx];
4344
max_idx = local_idx;
4445
}
4546
shared_vals[local_idx] = max_val;
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
//
5+
// MOE 1 token gate shader
6+
//
7+
// called with expert as local_idx
8+
// input: router_logits
9+
// output: topk_values
10+
// output: indirect_experts
11+
12+
#param is_fp16
13+
#param k
14+
15+
const K: u32 = k;
16+
#if is_fp16
17+
const MAX_FLOAT: f16 = 65504.0;
18+
#else
19+
const MAX_FLOAT: f32 = 3.4028234663852886e+38;
20+
#endif
21+
22+
var<workgroup> shared_vals: array<router_logits_element_t, workgroup_size_x>;
23+
var<workgroup> shared_idxs: array<u32, workgroup_size_x>;
24+
25+
$MAIN {
26+
let row = workgroup_idx;
27+
if (row >= uniforms.rows) {
28+
return;
29+
}
30+
let cols = uniforms.cols;
31+
let output_base = row * cols;
32+
33+
var max_val: router_logits_element_t = -MAX_FLOAT;
34+
var max_idx: u32 = 0u;
35+
36+
if (local_idx < cols) {
37+
max_val = router_logits[row * cols + local_idx];
38+
max_idx = local_idx;
39+
}
40+
shared_vals[local_idx] = max_val;
41+
shared_idxs[local_idx] = max_idx;
42+
topk_values[output_base + local_idx] = topk_values_value_t(0);
43+
workgroupBarrier();
44+
45+
// K is small, use a simple bubble sort
46+
for (var i = 0u; i < workgroup_size_x - 1u; i++) {
47+
for (var j = 0u; j < workgroup_size_x - 1u - i; j++) {
48+
if (local_idx == j && local_idx < cols && (local_idx + 1u) < cols) {
49+
// Compare adjacent elements and swap if needed (descending order)
50+
if (shared_vals[local_idx] < shared_vals[local_idx + 1u]) {
51+
let temp_val = shared_vals[local_idx];
52+
let temp_idx = shared_idxs[local_idx];
53+
shared_vals[local_idx] = shared_vals[local_idx + 1u];
54+
shared_idxs[local_idx] = shared_idxs[local_idx + 1u];
55+
shared_vals[local_idx + 1u] = temp_val;
56+
shared_idxs[local_idx + 1u] = temp_idx;
57+
}
58+
}
59+
workgroupBarrier();
60+
}
61+
}
62+
if (local_idx == 0u) {
63+
// softmax
64+
var sum : f32 = 0.0;
65+
for (var i = 0u; i < K; i++) {
66+
sum += exp(f32(shared_vals[i]));
67+
}
68+
for (var i = 0u; i < K; i++) {
69+
let expert_idx = shared_idxs[i];
70+
topk_values[output_base + expert_idx] = topk_values_value_t(exp(f32(shared_vals[i])) / sum);
71+
indirect_experts[i] = expert_idx;
72+
}
73+
}
74+
} // MAIN

onnxruntime/contrib_ops/webgpu/moe/qmoe.cc

Lines changed: 119 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class GateProgram final : public Program<GateProgram> {
2222
GateProgram(int k, bool is_fp16) : Program<GateProgram>{"QmoeGate"}, k_{k}, is_fp16_{is_fp16} {};
2323

2424
Status GenerateShaderCode(ShaderHelper& shader) const override {
25-
shader.AddInput("hidden_state", ShaderUsage::UseElementTypeAlias);
25+
shader.AddInput("router_logits", ShaderUsage::UseElementTypeAlias);
2626
shader.AddOutput("topk_values");
2727
shader.AddOutput("hiddenstate_for_expert");
2828
shader.AddOutput("tokencount_for_expert");
@@ -42,6 +42,29 @@ class GateProgram final : public Program<GateProgram> {
4242
bool is_fp16_;
4343
};
4444

45+
class Gate1TokenProgram final : public Program<Gate1TokenProgram> {
46+
public:
47+
Gate1TokenProgram(int k, bool is_fp16) : Program<Gate1TokenProgram>{"QmoeGate1Token"}, k_{k}, is_fp16_{is_fp16} {};
48+
49+
Status GenerateShaderCode(ShaderHelper& shader) const override {
50+
shader.AddInput("router_logits", ShaderUsage::UseElementTypeAlias);
51+
shader.AddOutput("topk_values");
52+
shader.AddOutput("indirect_experts");
53+
54+
return WGSL_TEMPLATE_APPLY(shader, "moe/gate_1token.wgsl.template",
55+
WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_),
56+
WGSL_TEMPLATE_PARAMETER(k, k_));
57+
};
58+
59+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
60+
{"rows", ProgramUniformVariableDataType::Uint32},
61+
{"cols", ProgramUniformVariableDataType::Uint32});
62+
63+
private:
64+
int k_;
65+
bool is_fp16_;
66+
};
67+
4568
class HiddenStateGatherProgram final : public Program<HiddenStateGatherProgram> {
4669
public:
4770
HiddenStateGatherProgram() : Program<HiddenStateGatherProgram>{"QmoeHiddenStateGather"} {};
@@ -115,7 +138,6 @@ class QMoEFinalMixProgram final : public Program<QMoEFinalMixProgram> {
115138
}
116139

117140
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
118-
{"used_by", ProgramUniformVariableDataType::Uint32},
119141
{"hidden_size", ProgramUniformVariableDataType::Uint32},
120142
{"num_experts", ProgramUniformVariableDataType::Uint32},
121143
{"expert_idx", ProgramUniformVariableDataType::Uint32},
@@ -124,6 +146,26 @@ class QMoEFinalMixProgram final : public Program<QMoEFinalMixProgram> {
124146
private:
125147
};
126148

149+
class QMoEFinalMix1TokenProgram final : public Program<QMoEFinalMix1TokenProgram> {
150+
public:
151+
QMoEFinalMix1TokenProgram() : Program<QMoEFinalMix1TokenProgram>{"QMoEFinalMix1TokenProgram"} {}
152+
153+
Status GenerateShaderCode(ShaderHelper& shader) const override {
154+
shader.AddInput("fc2_outputs", ShaderUsage::UseElementTypeAlias);
155+
shader.AddInput("router_values", ShaderUsage::UseElementTypeAlias);
156+
shader.AddInput("indirect_experts", ShaderUsage::UseElementTypeAlias);
157+
shader.AddOutput("output", ShaderUsage::UseElementTypeAlias);
158+
159+
return WGSL_TEMPLATE_APPLY(shader, "moe/final_mix_1token.wgsl.template");
160+
}
161+
162+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
163+
{"hidden_size", ProgramUniformVariableDataType::Uint32},
164+
{"expert_idx", ProgramUniformVariableDataType::Uint32});
165+
166+
private:
167+
};
168+
127169
Status QMoE::ComputeInternal(ComputeContext& context) const {
128170
const Tensor* hidden_state = context.Input<Tensor>(0);
129171
const Tensor* router_logits = context.Input<Tensor>(1);
@@ -168,7 +210,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
168210
}
169211

170212
// process tokens in chunks of max_tokens to put some cap on memory usage
171-
const int max_tokens = 512;
213+
const int max_tokens = 2 * 1024;
172214

173215
const uint32_t num_experts = static_cast<uint32_t>(moe_params.num_experts);
174216
const uint32_t hidden_size = static_cast<uint32_t>(moe_params.hidden_size);
@@ -197,6 +239,78 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
197239
.AddUniformVariables({static_cast<uint32_t>(total_output_size)});
198240
ORT_RETURN_IF_ERROR(context.RunProgram(zero));
199241

242+
if (moe_params.num_rows == 1) {
243+
// Optimized code path for 1 token to avoid gpu -> cpu copy
244+
245+
const int num_tokens = 1;
246+
TensorShape gate_value_shape({num_tokens, num_experts});
247+
TensorShape indirect_experts_shape({k_});
248+
249+
Tensor router_values = context.CreateGPUTensor(dtype, gate_value_shape);
250+
Tensor indirect_experts = context.CreateGPUTensor(dtype_uint32, indirect_experts_shape);
251+
252+
Gate1TokenProgram gate{k_, is_fp16};
253+
gate
254+
.AddInputs({{router_logits, ProgramTensorMetadataDependency::Type}})
255+
.AddOutput({&router_values, ProgramTensorMetadataDependency::None})
256+
.AddOutput({&indirect_experts, ProgramTensorMetadataDependency::None})
257+
.SetWorkgroupSize(num_experts)
258+
.SetDispatchGroupSize(static_cast<uint32_t>(num_tokens))
259+
.AddUniformVariables({static_cast<uint32_t>(num_tokens), num_experts})
260+
.CacheHint(k_, is_fp16 ? "fp16" : "fp32");
261+
262+
ORT_RETURN_IF_ERROR(context.RunProgram(gate));
263+
264+
for (uint32_t expert_idx = 0; expert_idx < static_cast<uint32_t>(k_); expert_idx++) {
265+
TensorShape fc1_output_shape({num_tokens, fc1_output_size});
266+
Tensor fc1_outputs = context.CreateGPUTensor(dtype, fc1_output_shape);
267+
TensorShape fc1_activated_shape({num_tokens, moe_params.inter_size});
268+
Tensor fc1_activated = context.CreateGPUTensor(dtype, fc1_activated_shape);
269+
TensorShape fc2_output_shape({num_tokens, N_fc2});
270+
Tensor fc2_outputs = context.CreateGPUTensor(dtype, fc2_output_shape);
271+
272+
status = ApplyMatMulNBits(hidden_state, fc1_experts_weights, fc1_scales, nullptr, fc1_experts_bias_optional,
273+
K_fc1, N_fc1, block_size_fc1, accuracy_level, expert_weight_bits_, context,
274+
&fc1_outputs, expert_idx, &indirect_experts);
275+
ORT_RETURN_IF_ERROR(status);
276+
277+
if (is_swiglu) {
278+
SwigLuProgram swiglu;
279+
swiglu
280+
.AddInputs({{&fc1_outputs, ProgramTensorMetadataDependency::Type, 2}})
281+
.AddOutput({&fc1_activated, ProgramTensorMetadataDependency::None})
282+
.SetWorkgroupSize(128)
283+
.SetDispatchGroupSize(((num_tokens * static_cast<uint32_t>(moe_params.inter_size)) + 127) / 128)
284+
.AddUniformVariables({static_cast<uint32_t>(num_tokens),
285+
static_cast<uint32_t>(moe_params.inter_size),
286+
activation_alpha_,
287+
activation_beta_,
288+
swiglu_limit_});
289+
ORT_RETURN_IF_ERROR(context.RunProgram(swiglu));
290+
} else {
291+
ORT_THROW("only swiglu is supported for WebGPU.");
292+
}
293+
294+
status = ApplyMatMulNBits(&fc1_activated, fc2_experts_weights, fc2_scales, nullptr, fc2_experts_bias_optional,
295+
K_fc2, N_fc2, block_size_fc2, accuracy_level, expert_weight_bits_, context,
296+
&fc2_outputs, expert_idx, &indirect_experts);
297+
ORT_RETURN_IF_ERROR(status);
298+
299+
QMoEFinalMix1TokenProgram final_mix;
300+
final_mix
301+
.AddInputs({{&fc2_outputs, ProgramTensorMetadataDependency::Type}})
302+
.AddInputs({{&router_values, ProgramTensorMetadataDependency::Type}})
303+
.AddInputs({{&indirect_experts, ProgramTensorMetadataDependency::Type}})
304+
.AddOutput({output_tensor, ProgramTensorMetadataDependency::None})
305+
.SetDispatchGroupSize(1)
306+
.AddUniformVariables({hidden_size, expert_idx});
307+
308+
ORT_RETURN_IF_ERROR(context.RunProgram(final_mix));
309+
}
310+
return Status::OK();
311+
}
312+
313+
// path for num_tokens > 1
200314
// process tokens in chunks of max_tokens to put some cap on memory usage
201315
for (int token_offset = 0; token_offset < moe_params.num_rows; token_offset += max_tokens) {
202316
//
@@ -226,9 +340,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
226340
.AddOutput({&gate_counts, ProgramTensorMetadataDependency::None, ProgramOutput::Atomic})
227341
.SetWorkgroupSize(num_experts)
228342
.SetDispatchGroupSize(static_cast<uint32_t>(num_tokens))
229-
.AddUniformVariables({static_cast<uint32_t>(num_tokens),
230-
num_experts,
231-
static_cast<uint32_t>(token_offset)})
343+
.AddUniformVariables({static_cast<uint32_t>(num_tokens), num_experts, static_cast<uint32_t>(token_offset)})
232344
.CacheHint(k_, is_fp16 ? "fp16" : "fp32");
233345

234346
ORT_RETURN_IF_ERROR(context.RunProgram(gate));
@@ -318,8 +430,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
318430
.AddInputs({{&expert_tokens, ProgramTensorMetadataDependency::Type}})
319431
.AddOutput({output_tensor, ProgramTensorMetadataDependency::None})
320432
.SetDispatchGroupSize(used_by)
321-
.AddUniformVariables({used_by,
322-
hidden_size,
433+
.AddUniformVariables({hidden_size,
323434
num_experts,
324435
expert_idx,
325436
static_cast<uint32_t>(token_offset)});

0 commit comments

Comments
 (0)