@@ -1024,40 +1024,16 @@ def forward(self, hidden_states: torch.Tensor,
1024
1024
return torch .ops .vllm .moe_forward (hidden_states , router_logits ,
1025
1025
self .layer_name )
1026
1026
1027
+
1027
1028
def forward_impl_chunked (self , full_hidden_states : torch .Tensor ,
1028
1029
full_router_logits : torch .Tensor ):
1029
1030
1030
- ctx = get_forward_context ()
1031
-
1032
- max_tokens_across_dp = ctx .dp_metadata .max_tokens_across_dp
1033
- #cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu
1034
- num_tokens_across_dp = ctx .dp_metadata .num_tokens_across_dp
1035
-
1036
- #In this function we define two ranges:
1037
- # 1. chunk_range - The current iteration of the loops's range over the
1038
- # DP world tokens
1039
- # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP
1040
- # rank owns.
1041
-
1042
- moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self .dp_size
1043
-
1044
- num_tokens_remaining_across_dp = num_tokens_across_dp
1045
- chunk_start = 0
1046
- chunk_end = min (moe_dp_chunk_size_per_rank ,
1047
- full_hidden_states .shape [0 ])
1048
1031
full_final_hidden_states = torch .empty_like (full_hidden_states )
1049
1032
1050
- assert full_hidden_states .shape [0 ] == full_router_logits .shape [0 ]
1051
-
1052
- for iter in range (0 , max_tokens_across_dp , moe_dp_chunk_size_per_rank ):
1033
+ def process_chunk (chunk_start , chunk_end , skip_result_store = False ):
1053
1034
hidden_states = full_hidden_states [chunk_start :chunk_end , :]
1054
1035
router_logits = full_router_logits [chunk_start :chunk_end , :]
1055
1036
1056
- cu_tokens_across_dp_this_iter = torch .cumsum (
1057
- num_tokens_remaining_across_dp .clamp (
1058
- max = moe_dp_chunk_size_per_rank ),
1059
- dim = 0 )
1060
-
1061
1037
# TODO: still may be needed for non-pplx, put into dispatcher class.
1062
1038
if False :
1063
1039
hidden_states = self .naive_multicast (
@@ -1103,30 +1079,22 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
1103
1079
final_hidden_states = tensor_model_parallel_all_reduce (
1104
1080
final_hidden_states )
1105
1081
1106
- full_final_hidden_states [chunk_start :chunk_end , :].copy_ (
1107
- final_hidden_states )
1108
-
1109
- # Update bounds
1110
- num_tokens_remaining_across_dp = torch .clamp (
1111
- num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank ,
1112
- min = 0 )
1082
+ if not skip_result_store :
1083
+ full_final_hidden_states [chunk_start :chunk_end , :].copy_ (
1084
+ final_hidden_states )
1113
1085
1114
- # HACK FIX
1115
- if num_tokens_remaining_across_dp .sum () == 0 :
1116
- break
1086
+ max_tokens_across_dp = get_forward_context ().dp_metadata .max_tokens_across_dp
1087
+ moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self .dp_size
1117
1088
1118
- def update_chunk_bound (x : int ):
1119
- return min (x + moe_dp_chunk_size_per_rank ,
1120
- full_hidden_states .shape [0 ])
1089
+ num_tokens = full_hidden_states .size (0 )
1090
+ for chunk_start_ in range (0 , max_tokens_across_dp , moe_dp_chunk_size_per_rank ):
1091
+ chunk_start = chunk_start_
1092
+ chunk_end = min (chunk_start + moe_dp_chunk_size_per_rank , max_tokens_across_dp )
1093
+ # clamp start and end
1094
+ chunk_start = min (chunk_start , num_tokens - 1 )
1095
+ chunk_end = min (chunk_end , num_tokens )
1121
1096
1122
- #chunk_start = update_chunk_bound(chunk_start)
1123
- #chunk_end = update_chunk_bound(chunk_end)
1124
- if chunk_end == full_hidden_states .shape [0 ]:
1125
- # simply redo computation
1126
- pass
1127
- else :
1128
- chunk_start = update_chunk_bound (chunk_start )
1129
- chunk_end = update_chunk_bound (chunk_end )
1097
+ process_chunk (chunk_start , chunk_end , skip_result_store = chunk_start_ >= num_tokens )
1130
1098
1131
1099
return full_final_hidden_states
1132
1100
0 commit comments