@@ -461,123 +461,114 @@ def apply(
461461 dispatch_output : StandardDispatchOutput ,
462462 ) -> CombineInput :
463463
464- from sglang .srt .layers .moe .cutlass_moe import cutlass_moe_fp4
465464 from sglang .srt .layers .moe .token_dispatcher import StandardCombineInput
466465
467466 x = dispatch_output .hidden_states
468467 topk_output = dispatch_output .topk_output
469- topk_weights , topk_ids = topk_output .topk_weights , topk_output .topk_ids
470-
471- output = cutlass_moe_fp4 (
472- a = x ,
473- a1_gscale = layer .w13_input_scale_quant ,
474- w1_fp4 = layer .w13_weight ,
475- w1_blockscale = layer .w13_weight_scale ,
476- w1_alphas = layer .g1_alphas ,
477- a2_gscale = layer .w2_input_scale_quant ,
478- w2_fp4 = layer .w2_weight ,
479- w2_blockscale = layer .w2_weight_scale ,
480- w2_alphas = layer .g2_alphas ,
481- topk_weights = topk_weights ,
482- topk_ids = topk_ids ,
483- params = layer .cutlass_moe_params ,
484- apply_router_weight_on_input = self .moe_runner_config .apply_router_weight_on_input ,
485- ).to (x .dtype )
486468
487- return StandardCombineInput (hidden_states = output )
488-
489- def apply_with_router_logits (
490- self ,
491- layer : torch .nn .Module ,
492- dispatch_output : StandardDispatchOutput ,
493- ) -> torch .Tensor :
494- assert self .use_flashinfer_trtllm
495-
496- x = dispatch_output .hidden_states
497- topk_output = dispatch_output .topk_output
498-
499- from flashinfer import fp4_quantize , trtllm_fp4_block_scale_moe
469+ if self .use_flashinfer_trtllm :
470+ from flashinfer import fp4_quantize , trtllm_fp4_block_scale_moe
500471
501- from sglang .srt .layers .moe .utils import RoutingMethodType
472+ router_logits = topk_output .router_logits
473+ topk_config = topk_output .topk_config
502474
503- router_logits = topk_output .router_logits
504- topk_config = topk_output .topk_config
475+ # Quantize input hidden states using fp4_quantize
476+ hs_fp4_bytes , hs_sf_bytes = fp4_quantize (
477+ x ,
478+ layer .w13_input_scale_quant ,
479+ self .group_size , # sf_vec_size
480+ False , # use_ue8m0
481+ False , # is_sf_swizzled_layout
482+ )
483+ hs_fp4 = hs_fp4_bytes .reshape (x .shape [0 ], x .shape [1 ] // 2 )
484+ hs_scale = hs_sf_bytes .view (torch .float8_e4m3fn ).reshape (- 1 )
505485
506- # Quantize input hidden states using fp4_quantize
507- hs_fp4_bytes , hs_sf_bytes = fp4_quantize (
508- x ,
509- layer .w13_input_scale_quant ,
510- self .group_size , # sf_vec_size
511- False , # use_ue8m0
512- False , # is_sf_swizzled_layout
513- )
514- hs_fp4 = hs_fp4_bytes .reshape (x .shape [0 ], x .shape [1 ] // 2 )
515- hs_scale = hs_sf_bytes .view (torch .float8_e4m3fn ).reshape (- 1 )
486+ correction_bias = (
487+ None
488+ if topk_config .correction_bias is None
489+ else topk_config .correction_bias .to (x .dtype )
490+ )
516491
517- correction_bias = (
518- None
519- if topk_config .correction_bias is None
520- else topk_config .correction_bias .to (x .dtype )
521- )
492+ assert layer .routing_method_type is not None
522493
523- assert layer .routing_method_type is not None
494+ # DeepSeekV3 style routing requires float32 router logits
495+ if layer .routing_method_type == RoutingMethodType .DeepSeekV3 :
496+ router_logits = router_logits .to (torch .float32 )
524497
525- # DeepSeekV3 style routing requires float32 router logits
526- if layer .routing_method_type == RoutingMethodType .DeepSeekV3 :
527- router_logits = router_logits .to (torch .float32 )
498+ routed_scaling_factor = self .moe_runner_config .routed_scaling_factor
499+ routed_scaling_factor = (
500+ routed_scaling_factor if routed_scaling_factor is not None else 1.0
501+ )
528502
529- routed_scaling_factor = self .moe_runner_config .routed_scaling_factor
530- routed_scaling_factor = (
531- routed_scaling_factor if routed_scaling_factor is not None else 1.0
532- )
503+ with use_symmetric_memory (
504+ get_tp_group (), disabled = not is_allocation_symmetric ()
505+ ):
506+ num_tokens = hs_fp4 .shape [0 ]
507+ hidden_size = (
508+ hs_fp4 .shape [- 1 ] * 2
509+ if hs_fp4 .dtype == torch .uint8
510+ else hs_fp4 .shape [- 1 ]
511+ )
512+ symm_output = torch .empty (
513+ num_tokens , hidden_size , dtype = torch .bfloat16 , device = hs_fp4 .device
514+ )
533515
534- with use_symmetric_memory (
535- get_tp_group (), disabled = not is_allocation_symmetric ()
536- ):
537- num_tokens = hs_fp4 .shape [0 ]
538- hidden_size = (
539- hs_fp4 .shape [- 1 ] * 2
540- if hs_fp4 .dtype == torch .uint8
541- else hs_fp4 .shape [- 1 ]
542- )
543- symm_output = torch .empty (
544- num_tokens , hidden_size , dtype = torch .bfloat16 , device = hs_fp4 .device
545- )
516+ output = trtllm_fp4_block_scale_moe (
517+ routing_logits = router_logits ,
518+ routing_bias = correction_bias ,
519+ hidden_states = hs_fp4 ,
520+ hidden_states_scale = hs_scale ,
521+ gemm1_weights = layer .gemm1_weights_fp4_shuffled ,
522+ gemm1_weights_scale = layer .gemm1_scales_fp4_shuffled .view (
523+ torch .float8_e4m3fn
524+ ),
525+ gemm1_bias = None ,
526+ gemm1_alpha = None ,
527+ gemm1_beta = None ,
528+ gemm1_clamp_limit = None ,
529+ gemm2_weights = layer .gemm2_weights_fp4_shuffled ,
530+ gemm2_weights_scale = layer .gemm2_scales_fp4_shuffled .view (
531+ torch .float8_e4m3fn
532+ ),
533+ gemm2_bias = None ,
534+ output1_scale_scalar = layer .g1_scale_c ,
535+ output1_scale_gate_scalar = layer .g1_alphas ,
536+ output2_scale_scalar = layer .g2_alphas ,
537+ num_experts = layer .num_experts ,
538+ top_k = topk_config .top_k ,
539+ n_group = topk_config .num_expert_group ,
540+ topk_group = topk_config .topk_group ,
541+ intermediate_size = layer .intermediate_size_per_partition ,
542+ local_expert_offset = layer .moe_ep_rank * layer .num_local_experts ,
543+ local_num_experts = layer .num_local_experts ,
544+ routed_scaling_factor = routed_scaling_factor ,
545+ routing_method_type = layer .routing_method_type ,
546+ do_finalize = True ,
547+ tune_max_num_tokens = next_power_of_2 (hs_fp4 .shape [0 ]),
548+ output = symm_output ,
549+ )[0 ]
550+ else :
551+ from sglang .srt .layers .moe .cutlass_moe import cutlass_moe_fp4
552+
553+ topk_weights , topk_ids = topk_output .topk_weights , topk_output .topk_ids
554+
555+ output = cutlass_moe_fp4 (
556+ a = x ,
557+ a1_gscale = layer .w13_input_scale_quant ,
558+ w1_fp4 = layer .w13_weight ,
559+ w1_blockscale = layer .w13_weight_scale ,
560+ w1_alphas = layer .g1_alphas ,
561+ a2_gscale = layer .w2_input_scale_quant ,
562+ w2_fp4 = layer .w2_weight ,
563+ w2_blockscale = layer .w2_weight_scale ,
564+ w2_alphas = layer .g2_alphas ,
565+ topk_weights = topk_weights ,
566+ topk_ids = topk_ids ,
567+ params = layer .cutlass_moe_params ,
568+ apply_router_weight_on_input = self .moe_runner_config .apply_router_weight_on_input ,
569+ ).to (x .dtype )
546570
547- return trtllm_fp4_block_scale_moe (
548- routing_logits = router_logits ,
549- routing_bias = correction_bias ,
550- hidden_states = hs_fp4 ,
551- hidden_states_scale = hs_scale ,
552- gemm1_weights = layer .gemm1_weights_fp4_shuffled ,
553- gemm1_weights_scale = layer .gemm1_scales_fp4_shuffled .view (
554- torch .float8_e4m3fn
555- ),
556- gemm1_bias = None ,
557- gemm1_alpha = None ,
558- gemm1_beta = None ,
559- gemm1_clamp_limit = None ,
560- gemm2_weights = layer .gemm2_weights_fp4_shuffled ,
561- gemm2_weights_scale = layer .gemm2_scales_fp4_shuffled .view (
562- torch .float8_e4m3fn
563- ),
564- gemm2_bias = None ,
565- output1_scale_scalar = layer .g1_scale_c ,
566- output1_scale_gate_scalar = layer .g1_alphas ,
567- output2_scale_scalar = layer .g2_alphas ,
568- num_experts = layer .num_experts ,
569- top_k = topk_config .top_k ,
570- n_group = topk_config .num_expert_group ,
571- topk_group = topk_config .topk_group ,
572- intermediate_size = layer .intermediate_size_per_partition ,
573- local_expert_offset = layer .moe_ep_rank * layer .num_local_experts ,
574- local_num_experts = layer .num_local_experts ,
575- routed_scaling_factor = routed_scaling_factor ,
576- routing_method_type = layer .routing_method_type ,
577- do_finalize = True ,
578- tune_max_num_tokens = next_power_of_2 (hs_fp4 .shape [0 ]),
579- output = symm_output ,
580- )[0 ]
571+ return StandardCombineInput (hidden_states = output )
581572
582573
583574class CompressedTensorsW8A8Fp8MoEMethod (CompressedTensorsMoEMethod ):
0 commit comments