     fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)

-MOE_DP_CHUNK_SIZE = 256
+# Note: this limit is somewhat arbitrary and might be changed later.
+MOE_DP_CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE


 @dataclass
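
The first hunk swaps the hardcoded per-rank chunk limit for the value of the
VLLM_FUSED_MOE_CHUNK_SIZE environment variable, read through vLLM's envs
module. A minimal, self-contained sketch of the resulting behavior is below;
the direct os.getenv lookup, the 256 fallback (mirroring the old constant),
and the dispatch_in_chunks helper are illustrative assumptions, not vLLM's
actual implementation.

import os

import torch

# Illustrative fallback: 256 matches the constant this diff removes; the real
# default is whatever vLLM's envs module defines for this variable.
MOE_DP_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "256"))


def dispatch_in_chunks(hidden_states: torch.Tensor,
                       router_logits: torch.Tensor):
    """Yield (hidden_states, router_logits) pieces of at most
    MOE_DP_CHUNK_SIZE tokens, so each data-parallel MoE dispatch stays
    under the configured limit."""
    for h, r in zip(torch.split(hidden_states, MOE_DP_CHUNK_SIZE, dim=0),
                    torch.split(router_logits, MOE_DP_CHUNK_SIZE, dim=0)):
        yield h, r
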
@@ -435,8 +436,6 @@ def set_prepare_finalize(

         experts: Optional[FusedMoEPermuteExpertsUnpermute] = None

-        self.using_pplx = False
-
         if isinstance(prepare_finalize,
                       (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
             logger.debug("BatchedTritonExperts %s", self.moe)
@@ -450,8 +449,6 @@ def set_prepare_finalize(
                 use_int4_w4a16=False,
                 block_shape=None,
             )
-            self.using_pplx = isinstance(prepare_finalize,
-                                         PplxPrepareAndFinalize)
         else:
             logger.debug("TritonExperts %s", self.moe)
             experts = TritonExperts(
@@ -499,7 +496,7 @@ def forward_cuda(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            indices_type=torch.uint32 if self.using_pplx else None)
+            indices_type=torch.uint32 if self.use_pplx_kernels else None)

         if self.rocm_aiter_moe_enabled:
             return self.rocm_aiter_fused_experts(
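
The hunks above also drop the mutable self.using_pplx flag that
set_prepare_finalize used to assign, and key the uint32 indices_type off the
existing self.use_pplx_kernels accessor instead, so the decision is derived
from the parallel configuration rather than from setup-time state. A small
sketch of that pattern follows; the class and field names (MoEParallelConfig,
dp_size, use_ep, all2all_backend) are hypothetical stand-ins for vLLM's actual
FusedMoEParallelConfig.

from dataclasses import dataclass


@dataclass
class MoEParallelConfig:
    # Hypothetical, simplified stand-in for vLLM's parallel config.
    dp_size: int = 1
    use_ep: bool = False
    all2all_backend: str = "naive"

    @property
    def use_pplx_kernels(self) -> bool:
        # Derived from configuration, not a flag mutated during kernel setup.
        return (self.dp_size > 1 and self.use_ep
                and self.all2all_backend == "pplx")


class FusedMoESketch:
    def __init__(self, moe_parallel_config: MoEParallelConfig):
        self.moe_parallel_config = moe_parallel_config

    @property
    def use_pplx_kernels(self) -> bool:
        # The layer simply forwards the config-derived property.
        return self.moe_parallel_config.use_pplx_kernels
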
@@ -828,7 +825,8 @@ def __init__(
             hidden_dim=hidden_size,
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
-            in_dtype=params_dtype,  # TODO: is this right?
+            # TODO (bnell): this needs to be fixed for quantized types.
+            in_dtype=params_dtype,
         )

         # Note: get_quant_method will look at the layer's local_num_experts