@@ -622,8 +622,7 @@ def __init__(
622
622
assert quant_method is not None
623
623
self .quant_method = quant_method
624
624
625
- dispatch_combine = self ._construct_dispatch_combine (
626
- moe , quant_config )
625
+ dispatch_combine = self ._construct_dispatch_combine (moe , quant_config )
627
626
628
627
success = self .quant_method .set_dispatch_combine (dispatch_combine )
629
628
@@ -1030,13 +1029,12 @@ def forward(self, hidden_states: torch.Tensor,
1030
1029
return torch .ops .vllm .moe_forward (hidden_states , router_logits ,
1031
1030
self .layer_name )
1032
1031
1033
-
1034
1032
def forward_impl_chunked (self , full_hidden_states : torch .Tensor ,
1035
1033
full_router_logits : torch .Tensor ):
1036
1034
1037
1035
full_final_hidden_states = torch .empty_like (full_hidden_states )
1038
1036
1039
- def process_chunk (chunk_start , chunk_end , skip_result_store = False ):
1037
+ def process_chunk (chunk_start , chunk_end , skip_result_store = False ):
1040
1038
hidden_states = full_hidden_states [chunk_start :chunk_end , :]
1041
1039
router_logits = full_router_logits [chunk_start :chunk_end , :]
1042
1040
@@ -1089,18 +1087,23 @@ def process_chunk(chunk_start, chunk_end, skip_result_store = False):
1089
1087
full_final_hidden_states [chunk_start :chunk_end , :].copy_ (
1090
1088
final_hidden_states )
1091
1089
1092
- max_tokens_across_dp = get_forward_context ().dp_metadata .max_tokens_across_dp
1090
+ max_tokens_across_dp = get_forward_context (
1091
+ ).dp_metadata .max_tokens_across_dp
1093
1092
moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self .dp_size
1094
1093
1095
1094
num_tokens = full_hidden_states .size (0 )
1096
- for chunk_start_ in range (0 , max_tokens_across_dp , moe_dp_chunk_size_per_rank ):
1097
- chunk_start = chunk_start_
1098
- chunk_end = min (chunk_start + moe_dp_chunk_size_per_rank , max_tokens_across_dp )
1095
+ for chunk_start_ in range (0 , max_tokens_across_dp ,
1096
+ moe_dp_chunk_size_per_rank ):
1097
+ chunk_start = chunk_start_
1098
+ chunk_end = min (chunk_start + moe_dp_chunk_size_per_rank ,
1099
+ max_tokens_across_dp )
1099
1100
# clamp start and end
1100
1101
chunk_start = min (chunk_start , num_tokens - 1 )
1101
1102
chunk_end = min (chunk_end , num_tokens )
1102
1103
1103
- process_chunk (chunk_start , chunk_end , skip_result_store = chunk_start_ >= num_tokens )
1104
+ process_chunk (chunk_start ,
1105
+ chunk_end ,
1106
+ skip_result_store = chunk_start_ >= num_tokens )
1104
1107
1105
1108
return full_final_hidden_states
1106
1109
0 commit comments