@@ -1023,40 +1023,16 @@ def forward(self, hidden_states: torch.Tensor,
1023
1023
return torch .ops .vllm .moe_forward (hidden_states , router_logits ,
1024
1024
self .layer_name )
1025
1025
1026
+
1026
1027
def forward_impl_chunked (self , full_hidden_states : torch .Tensor ,
1027
1028
full_router_logits : torch .Tensor ):
1028
1029
1029
- ctx = get_forward_context ()
1030
-
1031
- max_tokens_across_dp = ctx .dp_metadata .max_tokens_across_dp
1032
- #cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu
1033
- num_tokens_across_dp = ctx .dp_metadata .num_tokens_across_dp
1034
-
1035
- #In this function we define two ranges:
1036
- # 1. chunk_range - The current iteration of the loops's range over the
1037
- # DP world tokens
1038
- # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP
1039
- # rank owns.
1040
-
1041
- moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self .dp_size
1042
-
1043
- num_tokens_remaining_across_dp = num_tokens_across_dp
1044
- chunk_start = 0
1045
- chunk_end = min (moe_dp_chunk_size_per_rank ,
1046
- full_hidden_states .shape [0 ])
1047
1030
full_final_hidden_states = torch .empty_like (full_hidden_states )
1048
1031
1049
- assert full_hidden_states .shape [0 ] == full_router_logits .shape [0 ]
1050
-
1051
- for iter in range (0 , max_tokens_across_dp , moe_dp_chunk_size_per_rank ):
1032
+ def process_chunk (chunk_start , chunk_end , skip_result_store = False ):
1052
1033
hidden_states = full_hidden_states [chunk_start :chunk_end , :]
1053
1034
router_logits = full_router_logits [chunk_start :chunk_end , :]
1054
1035
1055
- cu_tokens_across_dp_this_iter = torch .cumsum (
1056
- num_tokens_remaining_across_dp .clamp (
1057
- max = moe_dp_chunk_size_per_rank ),
1058
- dim = 0 )
1059
-
1060
1036
# TODO: still may be needed for non-pplx, put into dispatcher class.
1061
1037
if False :
1062
1038
hidden_states = self .naive_multicast (
@@ -1102,30 +1078,22 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
1102
1078
final_hidden_states = tensor_model_parallel_all_reduce (
1103
1079
final_hidden_states )
1104
1080
1105
- full_final_hidden_states [chunk_start :chunk_end , :].copy_ (
1106
- final_hidden_states )
1107
-
1108
- # Update bounds
1109
- num_tokens_remaining_across_dp = torch .clamp (
1110
- num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank ,
1111
- min = 0 )
1081
+ if not skip_result_store :
1082
+ full_final_hidden_states [chunk_start :chunk_end , :].copy_ (
1083
+ final_hidden_states )
1112
1084
1113
- # HACK FIX
1114
- if num_tokens_remaining_across_dp .sum () == 0 :
1115
- break
1085
+ max_tokens_across_dp = get_forward_context ().dp_metadata .max_tokens_across_dp
1086
+ moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self .dp_size
1116
1087
1117
- def update_chunk_bound (x : int ):
1118
- return min (x + moe_dp_chunk_size_per_rank ,
1119
- full_hidden_states .shape [0 ])
1088
+ num_tokens = full_hidden_states .size (0 )
1089
+ for chunk_start_ in range (0 , max_tokens_across_dp , moe_dp_chunk_size_per_rank ):
1090
+ chunk_start = chunk_start_
1091
+ chunk_end = min (chunk_start + moe_dp_chunk_size_per_rank , max_tokens_across_dp )
1092
+ # clamp start and end
1093
+ chunk_start = min (chunk_start , num_tokens - 1 )
1094
+ chunk_end = min (chunk_end , num_tokens )
1120
1095
1121
- #chunk_start = update_chunk_bound(chunk_start)
1122
- #chunk_end = update_chunk_bound(chunk_end)
1123
- if chunk_end == full_hidden_states .shape [0 ]:
1124
- # simply redo computation
1125
- pass
1126
- else :
1127
- chunk_start = update_chunk_bound (chunk_start )
1128
- chunk_end = update_chunk_bound (chunk_end )
1096
+ process_chunk (chunk_start , chunk_end , skip_result_store = chunk_start_ >= num_tokens )
1129
1097
1130
1098
return full_final_hidden_states
1131
1099
0 commit comments