@@ -22,7 +22,7 @@ class GateProgram final : public Program<GateProgram> {
   GateProgram(int k, bool is_fp16) : Program<GateProgram>{"QmoeGate"}, k_{k}, is_fp16_{is_fp16} {};
 
   Status GenerateShaderCode(ShaderHelper& shader) const override {
-    shader.AddInput("hidden_state", ShaderUsage::UseElementTypeAlias);
+    shader.AddInput("router_logits", ShaderUsage::UseElementTypeAlias);
     shader.AddOutput("topk_values");
     shader.AddOutput("hiddenstate_for_expert");
     shader.AddOutput("tokencount_for_expert");
@@ -42,6 +42,29 @@ class GateProgram final : public Program<GateProgram> {
   bool is_fp16_;
 };
 
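+// Single-token gate: scores the router logits for one token and writes the selected
+// top-k expert indices to `indirect_experts`, keeping expert selection on the GPU.
+// The per-token gather/count buffers used by the batched GateProgram are not needed.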
+class Gate1TokenProgram final : public Program<Gate1TokenProgram> {
+ public:
+  Gate1TokenProgram(int k, bool is_fp16) : Program<Gate1TokenProgram>{"QmoeGate1Token"}, k_{k}, is_fp16_{is_fp16} {};
+
+  Status GenerateShaderCode(ShaderHelper& shader) const override {
+    shader.AddInput("router_logits", ShaderUsage::UseElementTypeAlias);
+    shader.AddOutput("topk_values");
+    shader.AddOutput("indirect_experts");
+
+    return WGSL_TEMPLATE_APPLY(shader, "moe/gate_1token.wgsl.template",
+                               WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_),
+                               WGSL_TEMPLATE_PARAMETER(k, k_));
+  };
+
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+      {"rows", ProgramUniformVariableDataType::Uint32},
+      {"cols", ProgramUniformVariableDataType::Uint32});
+
+ private:
+  int k_;
+  bool is_fp16_;
+};
+
 class HiddenStateGatherProgram final : public Program<HiddenStateGatherProgram> {
  public:
   HiddenStateGatherProgram() : Program<HiddenStateGatherProgram>{"QmoeHiddenStateGather"} {};
@@ -115,7 +138,6 @@ class QMoEFinalMixProgram final : public Program<QMoEFinalMixProgram> {
   }
 
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
-      {"used_by", ProgramUniformVariableDataType::Uint32},
       {"hidden_size", ProgramUniformVariableDataType::Uint32},
       {"num_experts", ProgramUniformVariableDataType::Uint32},
       {"expert_idx", ProgramUniformVariableDataType::Uint32},
@@ -124,6 +146,26 @@ class QMoEFinalMixProgram final : public Program<QMoEFinalMixProgram> {
  private:
 };
 
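+// Single-token final mix: adds one expert's fc2 output into the final output tensor,
+// weighted by the router value of the expert selected via `indirect_experts`.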
+class QMoEFinalMix1TokenProgram final : public Program<QMoEFinalMix1TokenProgram> {
+ public:
+  QMoEFinalMix1TokenProgram() : Program<QMoEFinalMix1TokenProgram>{"QMoEFinalMix1TokenProgram"} {}
+
+  Status GenerateShaderCode(ShaderHelper& shader) const override {
+    shader.AddInput("fc2_outputs", ShaderUsage::UseElementTypeAlias);
+    shader.AddInput("router_values", ShaderUsage::UseElementTypeAlias);
+    shader.AddInput("indirect_experts", ShaderUsage::UseElementTypeAlias);
+    shader.AddOutput("output", ShaderUsage::UseElementTypeAlias);
+
+    return WGSL_TEMPLATE_APPLY(shader, "moe/final_mix_1token.wgsl.template");
+  }
+
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+      {"hidden_size", ProgramUniformVariableDataType::Uint32},
+      {"expert_idx", ProgramUniformVariableDataType::Uint32});
+
+ private:
+};
+
 Status QMoE::ComputeInternal(ComputeContext& context) const {
   const Tensor* hidden_state = context.Input<Tensor>(0);
   const Tensor* router_logits = context.Input<Tensor>(1);
@@ -168,7 +210,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
   }
 
   // process tokens in chunks of max_tokens to put some cap on memory usage
-  const int max_tokens = 512;
+  const int max_tokens = 2 * 1024;
 
   const uint32_t num_experts = static_cast<uint32_t>(moe_params.num_experts);
   const uint32_t hidden_size = static_cast<uint32_t>(moe_params.hidden_size);
@@ -197,6 +239,78 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
       .AddUniformVariables({static_cast<uint32_t>(total_output_size)});
   ORT_RETURN_IF_ERROR(context.RunProgram(zero));
 
+  if (moe_params.num_rows == 1) {
+    // Optimized code path for 1 token to avoid a gpu -> cpu copy.
+
+    const int num_tokens = 1;
+    TensorShape gate_value_shape({num_tokens, num_experts});
+    TensorShape indirect_experts_shape({k_});
+
+    Tensor router_values = context.CreateGPUTensor(dtype, gate_value_shape);
+    Tensor indirect_experts = context.CreateGPUTensor(dtype_uint32, indirect_experts_shape);
+
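+    // Select the token's top-k experts on the GPU: router_values receives the gate
+    // values and indirect_experts the chosen expert ids used by the loop below.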
+    Gate1TokenProgram gate{k_, is_fp16};
+    gate
+        .AddInputs({{router_logits, ProgramTensorMetadataDependency::Type}})
+        .AddOutput({&router_values, ProgramTensorMetadataDependency::None})
+        .AddOutput({&indirect_experts, ProgramTensorMetadataDependency::None})
+        .SetWorkgroupSize(num_experts)
+        .SetDispatchGroupSize(static_cast<uint32_t>(num_tokens))
+        .AddUniformVariables({static_cast<uint32_t>(num_tokens), num_experts})
+        .CacheHint(k_, is_fp16 ? "fp16" : "fp32");
+
+    ORT_RETURN_IF_ERROR(context.RunProgram(gate));
+
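+    // Run each of the k selected experts; expert_idx indexes into indirect_experts,
+    // and each pass mixes its weighted fc2 result into the pre-zeroed output tensor.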
+    for (uint32_t expert_idx = 0; expert_idx < static_cast<uint32_t>(k_); expert_idx++) {
+      TensorShape fc1_output_shape({num_tokens, fc1_output_size});
+      Tensor fc1_outputs = context.CreateGPUTensor(dtype, fc1_output_shape);
+      TensorShape fc1_activated_shape({num_tokens, moe_params.inter_size});
+      Tensor fc1_activated = context.CreateGPUTensor(dtype, fc1_activated_shape);
+      TensorShape fc2_output_shape({num_tokens, N_fc2});
+      Tensor fc2_outputs = context.CreateGPUTensor(dtype, fc2_output_shape);
+
+      status = ApplyMatMulNBits(hidden_state, fc1_experts_weights, fc1_scales, nullptr, fc1_experts_bias_optional,
+                                K_fc1, N_fc1, block_size_fc1, accuracy_level, expert_weight_bits_, context,
+                                &fc1_outputs, expert_idx, &indirect_experts);
+      ORT_RETURN_IF_ERROR(status);
+
+      if (is_swiglu) {
+        SwigLuProgram swiglu;
+        swiglu
+            .AddInputs({{&fc1_outputs, ProgramTensorMetadataDependency::Type, 2}})
+            .AddOutput({&fc1_activated, ProgramTensorMetadataDependency::None})
+            .SetWorkgroupSize(128)
+            .SetDispatchGroupSize(((num_tokens * static_cast<uint32_t>(moe_params.inter_size)) + 127) / 128)
+            .AddUniformVariables({static_cast<uint32_t>(num_tokens),
+                                  static_cast<uint32_t>(moe_params.inter_size),
+                                  activation_alpha_,
+                                  activation_beta_,
+                                  swiglu_limit_});
+        ORT_RETURN_IF_ERROR(context.RunProgram(swiglu));
+      } else {
+        ORT_THROW("only swiglu is supported for WebGPU.");
+      }
+
+      status = ApplyMatMulNBits(&fc1_activated, fc2_experts_weights, fc2_scales, nullptr, fc2_experts_bias_optional,
+                                K_fc2, N_fc2, block_size_fc2, accuracy_level, expert_weight_bits_, context,
+                                &fc2_outputs, expert_idx, &indirect_experts);
+      ORT_RETURN_IF_ERROR(status);
+
+      QMoEFinalMix1TokenProgram final_mix;
+      final_mix
+          .AddInputs({{&fc2_outputs, ProgramTensorMetadataDependency::Type}})
+          .AddInputs({{&router_values, ProgramTensorMetadataDependency::Type}})
+          .AddInputs({{&indirect_experts, ProgramTensorMetadataDependency::Type}})
+          .AddOutput({output_tensor, ProgramTensorMetadataDependency::None})
+          .SetDispatchGroupSize(1)
+          .AddUniformVariables({hidden_size, expert_idx});
+
+      ORT_RETURN_IF_ERROR(context.RunProgram(final_mix));
+    }
+    return Status::OK();
+  }
+
+  // path for num_tokens > 1
   // process tokens in chunks of max_tokens to put some cap on memory usage
   for (int token_offset = 0; token_offset < moe_params.num_rows; token_offset += max_tokens) {
     //
@@ -226,9 +340,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
         .AddOutput({&gate_counts, ProgramTensorMetadataDependency::None, ProgramOutput::Atomic})
         .SetWorkgroupSize(num_experts)
         .SetDispatchGroupSize(static_cast<uint32_t>(num_tokens))
-        .AddUniformVariables({static_cast<uint32_t>(num_tokens),
-                              num_experts,
-                              static_cast<uint32_t>(token_offset)})
+        .AddUniformVariables({static_cast<uint32_t>(num_tokens), num_experts, static_cast<uint32_t>(token_offset)})
         .CacheHint(k_, is_fp16 ? "fp16" : "fp32");
 
     ORT_RETURN_IF_ERROR(context.RunProgram(gate));
@@ -318,8 +430,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
         .AddInputs({{&expert_tokens, ProgramTensorMetadataDependency::Type}})
        .AddOutput({output_tensor, ProgramTensorMetadataDependency::None})
         .SetDispatchGroupSize(used_by)
-        .AddUniformVariables({used_by,
-                              hidden_size,
+        .AddUniformVariables({hidden_size,
                               num_experts,
                               expert_idx,
                               static_cast<uint32_t>(token_offset)});