 import math
-from typing import Literal, TypedDict
+from typing import Literal, TypedDict, cast

+import numpy as np
 import ray
 import torch
 from ray.actor import ActorProxy

 from xtuner.v1.data_proto.sequence_context import SequenceContext
 from xtuner.v1.model.compose.base import BaseComposeConfig
+from xtuner.v1.rl.utils import get_seqlen_balanced_partitions
 from xtuner.v1.train.trainer import LoadCheckpointConfig
-from xtuner.v1.utils import ray_method
+from xtuner.v1.utils import get_logger, ray_method

 from .worker import TrainingWorker

@@ -23,6 +25,9 @@ class ColateItem(TypedDict):
 class RawTrainingController:
     def __init__(self, workers: list[TrainingWorker]) -> None:
         self.workers = workers
+        self.model_cfg = ray.get(self.workers[0].get_model_cfg.remote())
+        self.worker_cfg = ray.get(self.workers[0].get_worker_cfg.remote())
+        self.data_replicate_size = ray.get(self.workers[0].get_data_replicate_size.remote())
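+        # Cached once at construction: all workers are assumed to share the same model/worker
+        # config, so querying worker 0 is sufficient.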

     # TODO(hha): This logic is not generic enough; it should reuse the SFT packing function so that expand soft pack is also supported
     def _get_pack_infos(self, dataset, num_tokens, target, random=None):
@@ -164,95 +169,236 @@ def _grouped_by_max_length(self, packed_data_batches):
         # After sorting, this pack ends up first, so rank 0's first step usually consumes fewer effective tokens than the other ranks; this is expected behavior.
         return sorted(packed_data_batches, key=lambda x: x["seq_ctx"].max_length_q, reverse=True)

-    @ray_method
-    def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int):
-        has_rollout_routed_experts = False
-        language_cfg = None
-        if data_batches[0]["seq_ctx"].rollout_routed_experts is not None:
-            model_cfg = ray.get(self.workers[0].get_model_cfg.remote())  # type: ignore[attr-defined]
-            has_rollout_routed_experts = True
-            language_cfg = model_cfg
-            if isinstance(model_cfg, BaseComposeConfig):
-                language_cfg = model_cfg.text_config
-
-        packed_data_batches = self._packing(data_batches, pack_max_length, language_cfg)
-        # packed_data_batches = self._grouped_by_max_length(packed_data_batches)
+    def _balance_split_batch(self, data_batches, partition_size):
+        """Reorder the data on the single controller so that each dp rank gets a
+        similar total number of tokens."""
+        global_seqlen_lst = [data["seq_ctx"].input_ids.numel() for data in data_batches]
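+        # get_seqlen_balanced_partitions is expected to return `partition_size` lists of indices
+        # whose per-partition token sums are roughly equal; equal_size=True should additionally
+        # keep the number of samples per partition identical.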
+        global_partition_lst = get_seqlen_balanced_partitions(
+            global_seqlen_lst, k_partitions=partition_size, equal_size=True
+        )
+        balanced_batches = []
+        tokens_in_partition = []
+        for partition in global_partition_lst:
+            partition_batch = [data_batches[i] for i in partition]
+            tokens_in_partition.append(sum(data["seq_ctx"].input_ids.numel() for data in partition_batch))
+            balanced_batches.append(partition_batch)
+        get_logger().info(f"Balanced split into {partition_size} partitions with tokens: {tokens_in_partition}")
+        return balanced_batches
+
+    def _create_padding_sample(
+        self,
+        pad_len: int,
+        pack_max_length: int,
+        is_qwen3_vl: bool = False,
+        has_rollout_routed_experts: bool = False,
+        has_rollout_logprobs: bool = True,
+        n_routed_experts: int = 0,
+        split_size: int = 1024,
+    ):
+        # Pad input_ids in chunks of split_size: several short padding sequences keep the
+        # attention cost of the padding low.
+        pad_tokens = tuple(
+            torch.zeros(1, split_size, dtype=torch.long, device="cpu") for _ in range(pad_len // split_size)
+        )
+        if pad_len % split_size > 0:
+            pad_tokens = pad_tokens + (torch.zeros(1, pad_len % split_size, dtype=torch.long, device="cpu"),)
+        pad_tokens = cast(tuple[torch.LongTensor, ...], pad_tokens)
+        pad_seq_ctx = SequenceContext.from_input_ids(pad_tokens, device="cpu")
+        pad_seq_ctx.num_padding = pad_len
+
+        # padding multimodal position_ids (Qwen3-VL style, shape (3, 1, seq_len))
+        if is_qwen3_vl:
+            _position_ids_list = []
+            for pad_token in pad_tokens:
+                _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1)
+                _position_ids_list.append(_position_ids)
+            position_ids = torch.cat(_position_ids_list, dim=-1)
+            position_ids = cast(torch.LongTensor, position_ids)
+            pad_seq_ctx.position_ids = position_ids
+
+        # padding rollout routed experts
+        if has_rollout_routed_experts:
+            if pad_len == pack_max_length:
+                pad_rand_index = torch.randint(
+                    low=0, high=1, size=(1, 1, 1)
+                )  # add dummy data, true data will be initialized in train worker.fit
+            else:
+                pad_rand_index = torch.randint(low=0, high=n_routed_experts, size=(pad_len, 1, 1))
+            pad_seq_ctx.rollout_routed_experts = pad_rand_index

-        # TODO(hha): This logic is not generic enough; it is tied to the model
-        is_qwen3_vl = False
-        if len(packed_data_batches[0]["seq_ctx"].position_ids.shape) == 3:
-            is_qwen3_vl = True
+        pad_labels = torch.full((1, pad_len), -100, dtype=torch.long, device="cpu")

-        # todo: support round up
-        num_packed_data_batches = len(packed_data_batches)
-        data_replicate_size = ray.get(self.workers[0].get_data_replicate_size.remote())  # type: ignore[attr-defined]
-        dp_size = len(self.workers) // data_replicate_size
-        pad_num = math.ceil(num_packed_data_batches / dp_size) * dp_size - num_packed_data_batches
-        if pad_num > 0:
-            # Reduce the attn calculation time by using multiple short sequence packs
-            assert data_batches[0]["seq_ctx"].input_ids is not None
-            pad_tokens = tuple(
-                torch.zeros(1, 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu")
-                for _ in range(pack_max_length // 1024)
-            )
-            if pack_max_length % 1024 > 0:
-                assert data_batches[0]["seq_ctx"].input_ids is not None
-                pad_tokens = pad_tokens + (
-                    torch.zeros(
-                        1, pack_max_length % 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu"
-                    ),
-                )
-            pad_seq_ctx = SequenceContext.from_input_ids(pad_tokens, device="cpu")  # type: ignore
-            pad_seq_ctx.num_padding = pack_max_length
-            if is_qwen3_vl:
-                _position_ids_list = []
-                for pad_token in pad_tokens:
-                    _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1)
-                    _position_ids_list.append(_position_ids)
-                pad_seq_ctx.position_ids = torch.cat(_position_ids_list, dim=-1)  # type: ignore
-
-            pad_shifted_labels = torch.full(
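+        # For a whole dummy pack (pad_len == pack_max_length) the padded advantages are returned
+        # as a ready-made (1, pack_max_length) tensor of -100; for tail padding inside a real pack
+        # they are returned as one -100 scalar per padding chunk, which the caller later expands
+        # per token with repeat_interleave.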
+        if pad_len == pack_max_length:
+            pad_advantage_tensor = torch.full(
                 (1, pack_max_length),
                 -100,
-                dtype=packed_data_batches[0]["shifted_labels"].dtype,
+                dtype=torch.float32,
                 device="cpu",
             )
-            pad_advantages = torch.full(
-                (1, pack_max_length),
-                -100,
-                dtype=packed_data_batches[0]["advantages"].dtype,
-                device="cpu",
+        else:
+            pad_advantage_array = [-100] * math.ceil(pad_len / split_size)
+        pad_rollout_logprobs = (
+            torch.zeros(1, pad_len, dtype=torch.float32, device="cpu") if has_rollout_logprobs else None
+        )
+
+        return {
+            "seq_ctx": pad_seq_ctx,
+            "shifted_labels": pad_labels,
+            "advantages": pad_advantage_tensor if pad_len == pack_max_length else pad_advantage_array,
+            "rollout_logprobs": pad_rollout_logprobs,
+        }
+
+    def _pack(self, mini_batch, pack_max_length):
+        seqlen_list = []
+        for data in mini_batch:
+            assert data["seq_ctx"].input_ids.numel() <= pack_max_length, (
+                f"Single sample seq len {data['seq_ctx'].input_ids.numel()} exceeds pack_max_length {pack_max_length}"
             )
+            seqlen_list.append(data["seq_ctx"].input_ids.numel())
+        total_length = sum(seqlen_list)

-            if has_rollout_routed_experts:
-                pad_rand_index = torch.randint(
-                    low=0,
-                    high=1,
-                    size=(1, 1, 1),  # add dummy data, true data will be initialized in train worker.fit
-                )
-                pad_seq_ctx.rollout_routed_experts = pad_rand_index
+        if total_length <= pack_max_length:
+            return [mini_batch]  # No packing needed

-            pad_rollout_logprobs = None
-            if "rollout_logprobs" in packed_data_batches[0] and packed_data_batches[0]["rollout_logprobs"] is not None:
-                pad_rollout_logprobs = torch.zeros(
-                    1, pack_max_length, dtype=packed_data_batches[0]["rollout_logprobs"].dtype, device="cpu"
-                )
-            pad_data = {
-                "seq_ctx": pad_seq_ctx,
-                "shifted_labels": pad_shifted_labels,
-                "advantages": pad_advantages,
-                "rollout_logprobs": pad_rollout_logprobs,
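+        # Split into the minimum number of packs whose combined capacity covers all tokens;
+        # equal_size=False is assumed to let partitions differ in sample count while keeping
+        # their token sums balanced.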
+        num_packs = math.ceil(total_length / pack_max_length)
+        partitions_indices = get_seqlen_balanced_partitions(
+            seqlen_list=seqlen_list, k_partitions=num_packs, equal_size=False
+        )
+
+        packed_mini_batches = []
+        for partition in partitions_indices:
+            packed_batch = [mini_batch[i] for i in partition]
+            packed_mini_batches.append(packed_batch)
+        return packed_mini_batches
+
+    def _get_data_batches_properties(self, data_batches: list[ColateItem]):
+        """Extract properties from the first element of data_batches."""
+        if not data_batches:
+            return {
+                "is_qwen3_vl": False,
+                "has_rollout_routed_experts": False,
+                "has_rollout_logprobs": False,
+                "n_routed_experts": None,
             }
-            pad_data_samples = [pad_data for _ in range(pad_num)]
-            packed_data_batches = packed_data_batches + pad_data_samples

-        print(f"len(packed_data_batches): {len(packed_data_batches)}")
+        first_item = data_batches[0]
+        seq_ctx = first_item["seq_ctx"]
+
+        is_qwen3_vl = seq_ctx.position_ids is not None and len(seq_ctx.position_ids.shape) == 3
+        has_rollout_logprobs = "rollout_logprobs" in first_item and first_item["rollout_logprobs"] is not None
+        has_rollout_routed_experts = seq_ctx.rollout_routed_experts is not None
+
+        model_cfg = ray.get(self.workers[0].get_model_cfg.remote())  # type: ignore[attr-defined]
+        language_cfg = None
+        if has_rollout_routed_experts:
+            language_cfg = model_cfg
+            if isinstance(model_cfg, BaseComposeConfig):
+                language_cfg = model_cfg.text_config
+
+        return {
+            "is_qwen3_vl": is_qwen3_vl,
+            "has_rollout_routed_experts": has_rollout_routed_experts,
+            "has_rollout_logprobs": has_rollout_logprobs,
+            "n_routed_experts": language_cfg.n_routed_experts if language_cfg is not None else None,
+        }
+
+    @ray_method
+    def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int, enable_balance: bool = True):
+        batch_props = self._get_data_batches_properties(data_batches)
+        is_qwen3_vl = batch_props["is_qwen3_vl"]
+        has_rollout_routed_experts = batch_props["has_rollout_routed_experts"]
+        has_rollout_logprobs = batch_props["has_rollout_logprobs"]
+        n_routed_experts = batch_props["n_routed_experts"]
+
+        world_size = len(self.workers)
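+        # Workers form dp_size data-parallel groups of data_replicate_size workers each;
+        # every worker in a group receives the same packed batches (see the dispatch at the end of fit).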
+        dp_size = world_size // self.data_replicate_size
+        assert world_size % self.data_replicate_size == 0, "world_size must be divisible by data_replicate_size"
+        optimizer_steps = self.worker_cfg.optimizer_steps
+
+        if enable_balance:
+            batches_per_dp_group = self._balance_split_batch(data_batches, dp_size)
+        else:
+            batches_per_dp_group = np.array_split(data_batches, dp_size)
+
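+        # packed_data_batches[dp_rank][step_idx] collects the packs (gradient-accumulation
+        # micro-batches) that one data-parallel group runs within one optimizer step.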
+        packed_data_batches: list[list[list[dict]]] = [[[] for _ in range(optimizer_steps)] for _ in range(dp_size)]
+        max_packs_per_card = [0] * optimizer_steps
+
+        for dp_rank, dp_worker_data_batches in enumerate(batches_per_dp_group):
+            # Within each worker, split the tokens evenly across optimizer_steps
+            mini_batch_for_steps = self._balance_split_batch(dp_worker_data_batches, optimizer_steps)
+
+            for step_idx, step_mini_batch in enumerate(mini_batch_for_steps):
+                # pack this step's samples into packs of roughly pack_max_length tokens
+                pack_mini_batch = self._pack(step_mini_batch, pack_max_length)
+                if len(pack_mini_batch) > max_packs_per_card[step_idx]:
+                    max_packs_per_card[step_idx] = len(pack_mini_batch)
+
+                for pack in pack_mini_batch:
+                    seq_ctx_list = [item["seq_ctx"] for item in pack]
+                    label_list = [item["shifted_labels"] for item in pack]
+                    advantage_list = [torch.tensor([item["advantage"]]).float().unsqueeze(0) for item in pack]
+                    rollout_logprobs_list = [
+                        item["rollout_logprobs"] if has_rollout_logprobs else None for item in pack
+                    ]
+                    padding_len = pack_max_length - sum([item["seq_ctx"].input_ids.numel() for item in pack])
+                    if padding_len > 0:
+                        padding_sample = self._create_padding_sample(
+                            padding_len,
+                            pack_max_length,
+                            is_qwen3_vl=is_qwen3_vl,
+                            has_rollout_routed_experts=has_rollout_routed_experts,
+                            has_rollout_logprobs=has_rollout_logprobs,
+                            n_routed_experts=n_routed_experts,
+                        )
+                        seq_ctx_list.append(padding_sample["seq_ctx"])
+                        label_list.append(padding_sample["shifted_labels"])
+                        advantage_list.extend(padding_sample["advantages"])
+                        rollout_logprobs_list.append(padding_sample["rollout_logprobs"])
+
+                    packed_seq_ctx = SequenceContext.pack(seq_ctx_list)
+                    packed_shifted_labels = torch.cat(label_list, dim=1)
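+                    # Each sequence carries one scalar advantage; repeat_interleave below expands it
+                    # to every token of that sequence using the packed cu_seq_lens boundaries.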
+                    packed_advantages = torch.tensor(advantage_list).float().unsqueeze(0)
+                    cu_seq_lens_q = packed_seq_ctx.cu_seq_lens_q
+                    packed_num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
+                    packed_advantages = torch.repeat_interleave(packed_advantages, packed_num_tokens, dim=1)
+                    if has_rollout_logprobs:
+                        cast_rollout_logprobs_list = [cast(torch.Tensor, item) for item in rollout_logprobs_list]
+                        packed_rollout_logprobs = torch.cat(cast_rollout_logprobs_list, dim=1)
+                    else:
+                        packed_rollout_logprobs = None
+                    packed_data_batches[dp_rank][step_idx].append(
+                        {
+                            "seq_ctx": packed_seq_ctx,
+                            "shifted_labels": packed_shifted_labels,
+                            "advantages": packed_advantages,
+                            "rollout_logprobs": packed_rollout_logprobs,
+                        }
+                    )
+
+        get_logger().info(f"Gradient accumulation steps: {max_packs_per_card}")
+        # Pad every dp rank up to the same number of packs per optimizer step
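+        # (presumably so that all ranks run the same number of forward/backward passes and the
+        # collective ops stay in lockstep)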
+        for dp_rank in range(dp_size):
+            for step_idx in range(optimizer_steps):
+                max_packs = max_packs_per_card[step_idx]
+                num_current_packs = len(packed_data_batches[dp_rank][step_idx])
+                num_padding_packs = max_packs - num_current_packs
+
+                if num_padding_packs > 0:
+                    padding_sample = self._create_padding_sample(
+                        pack_max_length,
+                        pack_max_length,
+                        is_qwen3_vl=is_qwen3_vl,
+                        has_rollout_routed_experts=has_rollout_routed_experts,
+                        has_rollout_logprobs=has_rollout_logprobs,
+                        n_routed_experts=n_routed_experts,
+                    )
+                    padding_samples = [padding_sample for _ in range(num_padding_packs)]
+                    packed_data_batches[dp_rank][step_idx].extend(padding_samples)

         handles = []
         for worker_idx, worker in enumerate(self.workers):
             handles.append(
                 worker.fit.remote(  # type: ignore[attr-defined]
-                    data_batches=packed_data_batches[(worker_idx // data_replicate_size) :: dp_size],
+                    data_batches=packed_data_batches[worker_idx // self.data_replicate_size],
                     rollout_idx=rollout_idx,
                 )
             )