@@ -622,8 +622,7 @@ def __init__(
         assert quant_method is not None
         self.quant_method = quant_method

-        dispatch_combine = self._construct_dispatch_combine(
-            moe, quant_config)
+        dispatch_combine = self._construct_dispatch_combine(moe, quant_config)

         success = self.quant_method.set_dispatch_combine(dispatch_combine)
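The hunk above reflows the construct-then-inject pattern: the layer builds a dispatch/combine object and hands it to the quant method, which reports via a boolean whether it can use it. A minimal sketch of that pattern, with illustrative stand-in classes (not vLLM's actual interfaces):

```python
class DummyDispatchCombine:
    """Placeholder for an all-to-all token dispatch/combine implementation."""


class DummyQuantMethod:
    def __init__(self) -> None:
        self.dispatch_combine = None

    def set_dispatch_combine(self, dispatch_combine) -> bool:
        # Return False when this quant method cannot drive a modular
        # dispatch/combine kernel, so the caller can fall back.
        self.dispatch_combine = dispatch_combine
        return True


quant_method = DummyQuantMethod()
success = quant_method.set_dispatch_combine(DummyDispatchCombine())
if not success:
    # Fall back to the non-modular kernel path.
    pass
```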
@@ -1029,13 +1028,12 @@ def forward(self, hidden_states: torch.Tensor,
         return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                           self.layer_name)

-
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):

         full_final_hidden_states = torch.empty_like(full_hidden_states)

-        def process_chunk(chunk_start, chunk_end, skip_result_store = False):
+        def process_chunk(chunk_start, chunk_end, skip_result_store=False):
             hidden_states = full_hidden_states[chunk_start:chunk_end, :]
             router_logits = full_router_logits[chunk_start:chunk_end, :]
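Note that `forward` here delegates to `torch.ops.vllm.moe_forward`, a registered custom op, so the whole MoE body appears as a single opaque op to the dispatcher. A minimal sketch of how such an op can be registered with `torch.library`; the `demo_moe` namespace and the implementation body are illustrative assumptions, not vLLM's actual registration code:

```python
import torch

# Hypothetical namespace for illustration only.
_lib = torch.library.Library("demo_moe", "DEF")
_lib.define(
    "moe_forward(Tensor hidden_states, Tensor router_logits) -> Tensor")


def _moe_forward_impl(hidden_states: torch.Tensor,
                      router_logits: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real fused-MoE computation.
    return hidden_states.clone()


_lib.impl("moe_forward", _moe_forward_impl, "CPU")

# Callers dispatch through the op registry, as the diff does with
# torch.ops.vllm.moe_forward.
out = torch.ops.demo_moe.moe_forward(torch.randn(4, 8), torch.randn(4, 2))
```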
@@ -1088,18 +1086,23 @@ def process_chunk(chunk_start, chunk_end, skip_result_store = False):
             full_final_hidden_states[chunk_start:chunk_end, :].copy_(
                 final_hidden_states)

-        max_tokens_across_dp = get_forward_context().dp_metadata.max_tokens_across_dp
+        max_tokens_across_dp = get_forward_context(
+        ).dp_metadata.max_tokens_across_dp
         moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size

         num_tokens = full_hidden_states.size(0)
-        for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
-            chunk_start = chunk_start_
-            chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dp)
+        for chunk_start_ in range(0, max_tokens_across_dp,
+                                  moe_dp_chunk_size_per_rank):
+            chunk_start = chunk_start_
+            chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
+                            max_tokens_across_dp)

             # clamp start and end
             chunk_start = min(chunk_start, num_tokens - 1)
             chunk_end = min(chunk_end, num_tokens)

-            process_chunk(chunk_start, chunk_end, skip_result_store = chunk_start_ >= num_tokens)
+            process_chunk(chunk_start,
+                          chunk_end,
+                          skip_result_store=chunk_start_ >= num_tokens)

         return full_final_hidden_states
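The loop being reflowed here iterates to `max_tokens_across_dp`, not the rank's own `num_tokens`, so every data-parallel rank executes the same number of chunk steps (and hence the same number of collectives); a rank with fewer local tokens runs dummy chunks on clamped slices and passes `skip_result_store=True` to discard their output. A self-contained sketch of that logic outside vLLM, assuming a made-up `CHUNK_SIZE` and a stand-in per-chunk computation:

```python
import torch

CHUNK_SIZE = 4  # stand-in for MOE_DP_CHUNK_SIZE // dp_size


def chunked_forward(full_hidden_states: torch.Tensor,
                    max_tokens_across_dp: int) -> torch.Tensor:
    num_tokens = full_hidden_states.size(0)
    out = torch.empty_like(full_hidden_states)

    for chunk_start_ in range(0, max_tokens_across_dp, CHUNK_SIZE):
        # Clamp so dummy chunks still index a valid (non-empty) slice.
        chunk_start = min(chunk_start_, num_tokens - 1)
        chunk_end = min(chunk_start_ + CHUNK_SIZE, num_tokens)
        # Stand-in for the real per-chunk MoE computation.
        result = full_hidden_states[chunk_start:chunk_end, :] * 2
        if chunk_start_ < num_tokens:
            # Real chunk: keep the result.
            out[chunk_start:chunk_end, :].copy_(result)
        # Dummy chunk: computed only to stay in lockstep with the rank
        # that owns max_tokens_across_dp tokens; the result is discarded.

    return out


# A rank with 6 local tokens while the largest rank has 10 runs three chunk
# steps, the last of which is a dummy.
x = torch.randn(6, 8)
y = chunked_forward(x, max_tokens_across_dp=10)
assert torch.allclose(y, x * 2)
```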