import math
import os
import random
-from typing import Literal, TypedDict, cast
+from pathlib import Path
+from typing import Literal, cast

import numpy as np
import ray
from xtuner.v1.train.trainer import LoadCheckpointConfig
from xtuner.v1.utils import get_logger, ray_method

-from .worker import TrainingWorker, WorkerLogItem
-

TRAIN_RAY_GET_TIMEOUT = int(os.getenv("XTUNER_TRAIN_RAY_GET_TIMEOUT", 5 * 3600))  # default 5 hours; env var values are strings, so coerce to int

-
-class ColateItem(TypedDict):
-    seq_ctx: SequenceContext
-    shifted_labels: torch.Tensor
-    advantage: float
-    rollout_logprobs: torch.Tensor | None
+from .worker import TrainingWorker, WorkerInputItem, WorkerLogItem


class RawTrainingController:
@@ -36,6 +30,17 @@ def __init__(self, workers: list[TrainingWorker]) -> None:
            self.workers[0].get_data_replicate_size.remote(),
        ]
        self.model_cfg, self.worker_cfg, self.data_replicate_size = ray.get(refs)
+        log_dir = self.worker_cfg.log_dir
+        self.log_dir = None
+        if log_dir is not None:
+            self.log_dir = Path(log_dir) if isinstance(log_dir, str) else log_dir
+            self.logger = get_logger(log_dir=self.log_dir, tag="TrainingController")
+        else:
+            self.logger = get_logger()
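+        # Per-batch data properties; refreshed on each fit() call by _set_data_batches_properties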
+        self.is_qwen3_vl = False
+        self.has_rollout_routed_experts = False
+        self.has_rollout_logprobs = False
+        self.n_routed_experts = None

    # TODO(hha): This logic is not generic enough; it should reuse the sft function to support expand soft pack
    def _get_pack_infos(self, dataset, num_tokens, target, random=None):
@@ -100,7 +105,7 @@ def _packing(self, data_batches, pack_max_length, language_cfg):
            pad_len = pack_max_length - total_len
            seq_ctx_list = [data_batches[i]["seq_ctx"] for i in indices]
            label_list = [data_batches[i]["shifted_labels"] for i in indices]
-            advantage_list = [data_batches[i]["advantage"] for i in indices]
+            advantage_list = [data_batches[i]["advantages"] for i in indices]

            rollout_logprobs_list = None
            if "rollout_logprobs" in data_batches[0] and data_batches[0]["rollout_logprobs"] is not None:
@@ -177,10 +182,10 @@ def _grouped_by_max_length(self, packed_data_batches):
        # After sorting, this pack is placed first, so the first step on rank0 usually consumes fewer valid tokens than the other ranks; this is expected.
        return sorted(packed_data_batches, key=lambda x: x["seq_ctx"].max_length_q, reverse=True)

-    def _balance_split_batch(self, data_batches, partition_size):
+    def _balance_split_batch(self, data_batches: list[WorkerInputItem], partition_size) -> list[list[WorkerInputItem]]:
        """Reorder the data on single controller such that each dp rank gets
        similar total tokens."""
-        global_seqlen_lst = [data["seq_ctx"].input_ids.numel() for data in data_batches]
+        global_seqlen_lst = [data["seq_ctx"].input_ids.numel() for data in data_batches]  # type: ignore[union-attr]
        global_partition_lst = get_seqlen_balanced_partitions(
            global_seqlen_lst, k_partitions=partition_size, equal_size=True
        )
@@ -193,16 +198,12 @@ def _balance_split_batch(self, data_batches, partition_size):
        get_logger().info(f"Balanced split into {partition_size} partitions with tokens: {tokens_in_partition}")
        return balanced_batches

-    def _create_padding_sample(
+    def _create_padding_item(
        self,
        pad_len: int,
        pack_max_length: int,
-        is_qwen3_vl: bool = False,
-        has_rollout_routed_experts: bool = False,
-        has_rollout_logprobs: bool = True,
-        n_routed_experts: int | None = None,
        split_size: int = 1024,
-    ):
+    ) -> WorkerInputItem:
        # padding input_ids
        pad_tokens = tuple(
            torch.zeros(1, split_size, dtype=torch.long, device="cpu") for _ in range(pad_len // split_size)
@@ -214,7 +215,7 @@ def _create_padding_sample(
        pad_seq_ctx.num_padding = pad_len

        # padding mm position_ids
-        if is_qwen3_vl:
+        if self.is_qwen3_vl:
            _position_ids_list = []
            for pad_token in pad_tokens:
                _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1)
@@ -224,17 +225,17 @@ def _create_padding_sample(
            pad_seq_ctx.position_ids = position_ids

        # padding rollout routed experts
-        if has_rollout_routed_experts:
-            assert n_routed_experts, "n_routed_experts must be provided when has_rollout_routed_experts is True"
+        if self.has_rollout_routed_experts:
+            assert self.n_routed_experts, "n_routed_experts must be provided when has_rollout_routed_experts is True"
            if pad_len == pack_max_length:
                pad_rand_index = torch.randint(
                    low=0, high=1, size=(1, 1, 1)
                )  # add dummy data, true data will be initialized in train worker.fit
            else:
-                pad_rand_index = torch.randint(low=0, high=n_routed_experts, size=(pad_len, 1, 1))
+                pad_rand_index = torch.randint(low=0, high=self.n_routed_experts, size=(pad_len, 1, 1))
            pad_seq_ctx.rollout_routed_experts = pad_rand_index

-        pad_labels = torch.full((1, pad_len), -100, dtype=torch.long, device="cpu")
+        pad_labels = cast(torch.LongTensor, torch.full((1, pad_len), -100, dtype=torch.int64, device="cpu"))
        pad_advantage_length = pack_max_length if pad_len == pack_max_length else math.ceil(pad_len / 1024)
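+        # Full-length padding packs carry per-token advantages directly; partial padding contributes
+        # one scalar per split_size-token chunk, which _pad_and_pack_batches later expands per token.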
        pad_advantage = torch.full(
            (1, pad_advantage_length),
@@ -243,24 +244,27 @@ def _create_padding_sample(
            device="cpu",
        )
        pad_rollout_logprobs = (
-            torch.zeros(1, pad_len, dtype=torch.float32, device="cpu") if has_rollout_logprobs else None
+            torch.zeros(1, pad_len, dtype=torch.float32, device="cpu") if self.has_rollout_logprobs else None
        )

-        return {
+        padding_item: WorkerInputItem = {
            "seq_ctx": pad_seq_ctx,
            "shifted_labels": pad_labels,
            "advantages": pad_advantage,
            "rollout_logprobs": pad_rollout_logprobs,
        }
+        return padding_item

-    def _pack(self, mini_batch, pack_max_length):
+    def _rearrange_batch_for_pack(
+        self, mini_batch: list[WorkerInputItem], pack_max_length: int
+    ) -> list[list[WorkerInputItem]]:
        assert len(mini_batch) > 0, "mini_batch should not be empty"
        seqlen_list = []
        for data in mini_batch:
-            assert data["seq_ctx"].input_ids.numel() <= pack_max_length, (
-                f"Single sample seq len {data['seq_ctx'].input_ids.numel()} exceeds pack_max_length {pack_max_length}"
+            assert data["seq_ctx"].input_ids.numel() <= pack_max_length, (  # type: ignore[union-attr]
+                f"Single sample seq len {data['seq_ctx'].input_ids.numel()} exceeds pack_max_length {pack_max_length}"  # type: ignore[union-attr]
            )
-            seqlen_list.append(data["seq_ctx"].input_ids.numel())
+            seqlen_list.append(data["seq_ctx"].input_ids.numel())  # type: ignore[union-attr]
        total_length = sum(seqlen_list)

        if total_length <= pack_max_length:
@@ -277,15 +281,10 @@ def _pack(self, mini_batch, pack_max_length):
            packed_mini_batches.append(packed_batch)
        return packed_mini_batches

-    def _get_data_batches_properties(self, data_batches: list[ColateItem]):
+    def _set_data_batches_properties(self, data_batches: list[WorkerInputItem]):
        """Extract properties from the first element of data_batches."""
        if not data_batches:
-            return {
-                "is_qwen3_vl": False,
-                "has_rollout_routed_experts": False,
-                "has_rollout_logprobs": False,
-                "n_routed_experts": None,
-            }
+            return

        first_item = data_batches[0]
        seq_ctx = first_item["seq_ctx"]
@@ -300,114 +299,128 @@ def _get_data_batches_properties(self, data_batches: list[ColateItem]):
        if isinstance(self.model_cfg, BaseComposeConfig):
            language_cfg = self.model_cfg.text_config

-        return {
-            "is_qwen3_vl": is_qwen3_vl,
-            "has_rollout_routed_experts": has_rollout_routed_experts,
-            "has_rollout_logprobs": has_rollout_logprobs,
-            "n_routed_experts": language_cfg.n_routed_experts if language_cfg is not None else None,
+        self.is_qwen3_vl = is_qwen3_vl
+        self.has_rollout_routed_experts = has_rollout_routed_experts
+        self.has_rollout_logprobs = has_rollout_logprobs
+        self.n_routed_experts = language_cfg.n_routed_experts if language_cfg is not None else None
+
+    def _pad_and_pack_batches(self, batch4pack: list[WorkerInputItem], pack_max_length: int) -> WorkerInputItem:
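+        """Pad batch4pack up to pack_max_length and pack it into a single WorkerInputItem."""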
+        seq_ctx_list = [item["seq_ctx"] for item in batch4pack]
+        label_list = [item["shifted_labels"] for item in batch4pack]
+        advantage_list = [torch.tensor([item["advantages"]]).float().unsqueeze(0) for item in batch4pack]
+        rollout_logprobs_list = [
+            item["rollout_logprobs"] if self.has_rollout_logprobs else None for item in batch4pack
+        ]
+        cur_length = 0
+        for item in batch4pack:
+            cur_length += item["seq_ctx"].input_ids.numel()  # type: ignore[union-attr]
+        padding_len = pack_max_length - cur_length
+
+        if padding_len > 0:
+            padding_item = self._create_padding_item(padding_len, pack_max_length)
+            seq_ctx_list.append(padding_item["seq_ctx"])
+            label_list.append(padding_item["shifted_labels"])
+            advantage_list.append(padding_item["advantages"])
+            rollout_logprobs_list.append(padding_item["rollout_logprobs"])
+
+        packed_seq_ctx = SequenceContext.pack(seq_ctx_list)
+        packed_shifted_labels = torch.cat(label_list, dim=1)  # type: ignore[arg-type]
+        packed_shifted_labels = cast(torch.LongTensor, packed_shifted_labels)
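+        # cu_seq_lens_q holds cumulative sequence boundaries; diffing adjacent entries yields per-sequence
+        # token counts, used to expand each sequence's scalar advantage to every one of its tokens.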
+        cu_seq_lens_q = packed_seq_ctx.cu_seq_lens_q
+        packed_num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
+        packed_advantages = torch.cat(advantage_list, dim=1)
+        packed_advantages = torch.repeat_interleave(packed_advantages, packed_num_tokens, dim=1)
+        if self.has_rollout_logprobs:
+            cast_rollout_logprobs_list = [cast(torch.Tensor, item) for item in rollout_logprobs_list]
+            packed_rollout_logprobs = torch.cat(cast_rollout_logprobs_list, dim=1)
+        else:
+            packed_rollout_logprobs = None
+
+        optimizer_step_packs: WorkerInputItem = {
+            "seq_ctx": packed_seq_ctx,
+            "shifted_labels": packed_shifted_labels,
+            "advantages": packed_advantages,
+            "rollout_logprobs": packed_rollout_logprobs,
        }
+        return optimizer_step_packs
+
+    def _pad_to_max_packs_across_workers(
+        self,
+        packed_data_batches: list[list[list[WorkerInputItem]]],
+        step_idx: int,
+        max_packs: int,
+        pack_max_length: int,
+    ):
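+        """Pad every dp rank's pack list at step_idx up to max_packs with full-length padding packs.
+
+        Note: the padding item is created once and shared by reference across all padding packs.
+        """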
+        for dp_rank in range(len(packed_data_batches)):
+            num_current_packs = len(packed_data_batches[dp_rank][step_idx])
+            num_padding_packs = max_packs - num_current_packs
+
+            if num_padding_packs > 0:
+                padding_item = self._create_padding_item(pack_max_length, pack_max_length)
+                padding_items = [padding_item for _ in range(num_padding_packs)]
+                packed_data_batches[dp_rank][step_idx].extend(padding_items)

    @ray_method
    def fit(
-        self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int, enable_dp_balance: bool = True
-    ):
-        batch_props = self._get_data_batches_properties(data_batches)
-        is_qwen3_vl = batch_props["is_qwen3_vl"]
-        has_rollout_routed_experts = batch_props["has_rollout_routed_experts"]
-        has_rollout_logprobs = batch_props["has_rollout_logprobs"]
-        n_routed_experts = batch_props["n_routed_experts"]
+        self,
+        data_batches: list[WorkerInputItem],
+        pack_max_length: int,
+        rollout_idx: int,
+        enable_dp_balance: bool = True,
+    ) -> list[WorkerLogItem]:
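+        """Balance, pack, and pad data_batches per dp rank and optimizer step, then dispatch them to the workers."""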
+        self._set_data_batches_properties(data_batches)

        world_size = len(self.workers)
        dp_size = world_size // self.data_replicate_size
        assert world_size % self.data_replicate_size == 0, "world_size must be divisible by data_replicate_size"
        optimizer_steps = self.worker_cfg.optimizer_steps

+        batches_per_dp_group: list[list[WorkerInputItem]]
        if enable_dp_balance:
            # Redistribute the data across dp ranks so that each dp rank gets roughly the same number of tokens
            batches_per_dp_group = self._balance_split_batch(data_batches, dp_size)
        else:
            batches_per_dp_group = np.array_split(data_batches, dp_size)
            tokens_in_partition = []
            for batch in batches_per_dp_group:
-                tokens_in_partition.append(sum(data["seq_ctx"].input_ids.numel() for data in batch))
-            get_logger().info(f"default split into {dp_size} partitions with tokens: {tokens_in_partition}")
-
-        packed_data_batches: list[list[list[dict]]] = [[[] for _ in range(optimizer_steps)] for _ in range(dp_size)]
-        max_packs_per_card = [0] * optimizer_steps
+                dp_group_total_tokens = 0
+                for data in batch:
+                    dp_group_total_tokens += data["seq_ctx"].input_ids.numel()  # type: ignore[union-attr]
+                tokens_in_partition.append(dp_group_total_tokens)
+            self.logger.info(f"default split into {dp_size} partitions with tokens: {tokens_in_partition}")
+
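+        # packed_data_batches[dp_rank][step_idx] holds the list of packs (gradient-accumulation
+        # micro-batches) that the dp_rank worker will run for optimizer step step_idx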
+        packed_data_batches: list[list[list[WorkerInputItem]]] = [
+            [[] for _ in range(optimizer_steps)] for _ in range(dp_size)
+        ]
+        max_packs_per_step = [0] * optimizer_steps

        for dp_rank, dp_worker_data_batches in enumerate(batches_per_dp_group):
            # Within each worker, split the tokens evenly across optimizer_steps
            if enable_dp_balance:
                random.shuffle(dp_worker_data_batches)
-            mini_batch_for_steps = self._balance_split_batch(dp_worker_data_batches, optimizer_steps)
+            mini_batch_for_steps: list[list[WorkerInputItem]] = self._balance_split_batch(
+                dp_worker_data_batches, optimizer_steps
+            )

            for step_idx, step_mini_batch in enumerate(mini_batch_for_steps):
-                # pack
-                pack_mini_batch = self._pack(step_mini_batch, pack_max_length)
-                if len(pack_mini_batch) > max_packs_per_card[step_idx]:
-                    max_packs_per_card[step_idx] = len(pack_mini_batch)
-
-                for pack in pack_mini_batch:
-                    seq_ctx_list = [item["seq_ctx"] for item in pack]
-                    label_list = [item["shifted_labels"] for item in pack]
-                    advantage_list = [torch.tensor([item["advantage"]]).float().unsqueeze(0) for item in pack]
-                    rollout_logprobs_list = [
-                        item["rollout_logprobs"] if has_rollout_logprobs else None for item in pack
-                    ]
-                    padding_len = pack_max_length - sum([item["seq_ctx"].input_ids.numel() for item in pack])
-                    if padding_len > 0:
-                        padding_sample = self._create_padding_sample(
-                            padding_len,
-                            pack_max_length,
-                            is_qwen3_vl=is_qwen3_vl,
-                            has_rollout_routed_experts=has_rollout_routed_experts,
-                            has_rollout_logprobs=has_rollout_logprobs,
-                            n_routed_experts=n_routed_experts,
-                        )
-                        seq_ctx_list.append(padding_sample["seq_ctx"])
-                        label_list.append(padding_sample["shifted_labels"])
-                        advantage_list.append(padding_sample["advantages"])
-                        rollout_logprobs_list.append(padding_sample["rollout_logprobs"])
-
-                    packed_seq_ctx = SequenceContext.pack(seq_ctx_list)
-                    packed_shifted_labels = torch.cat(label_list, dim=1)
-                    cu_seq_lens_q = packed_seq_ctx.cu_seq_lens_q
-                    packed_num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
-                    packed_advantages = torch.cat(advantage_list, dim=1)
-                    packed_advantages = torch.repeat_interleave(packed_advantages, packed_num_tokens, dim=1)
-                    if has_rollout_logprobs:
-                        cast_rollout_logprobs_list = [cast(torch.Tensor, item) for item in rollout_logprobs_list]
-                        packed_rollout_logprobs = torch.cat(cast_rollout_logprobs_list, dim=1)
-                    else:
-                        packed_rollout_logprobs = None
-                    packed_data_batches[dp_rank][step_idx].append(
-                        {
-                            "seq_ctx": packed_seq_ctx,
-                            "shifted_labels": packed_shifted_labels,
-                            "advantages": packed_advantages,
-                            "rollout_logprobs": packed_rollout_logprobs,
-                        }
-                    )
+                # rearrange mini batch to fit into packs of pack_max_length
+                batch4pack_list: list[list[WorkerInputItem]] = self._rearrange_batch_for_pack(
+                    step_mini_batch, pack_max_length
+                )
+                if len(batch4pack_list) > max_packs_per_step[step_idx]:
+                    max_packs_per_step[step_idx] = len(batch4pack_list)

-        get_logger().info(f"Gradient accumulation steps: {max_packs_per_card}")
-        # padding for each worker to have same number of packs
-        for dp_rank in range(dp_size):
-            for step_idx in range(optimizer_steps):
-                max_packs = max_packs_per_card[step_idx]
-                num_current_packs = len(packed_data_batches[dp_rank][step_idx])
-                num_padding_packs = max_packs - num_current_packs
-
-                if num_padding_packs > 0:
-                    padding_sample = self._create_padding_sample(
-                        pack_max_length,
-                        pack_max_length,
-                        is_qwen3_vl=is_qwen3_vl,
-                        has_rollout_routed_experts=has_rollout_routed_experts,
-                        has_rollout_logprobs=has_rollout_logprobs,
-                        n_routed_experts=n_routed_experts,
-                    )
-                    padding_samples = [padding_sample for _ in range(num_padding_packs)]
-                    packed_data_batches[dp_rank][step_idx].extend(padding_samples)
+                for batch4pack in batch4pack_list:
+                    # pad and pack batches into a single optimizer step pack
+                    step_pack = self._pad_and_pack_batches(batch4pack, pack_max_length)
+                    packed_data_batches[dp_rank][step_idx].append(step_pack)
+
+        self.logger.info(f"Gradient accumulation for each optimizer step: {max_packs_per_step}")
+
+        # pad so that every worker has the same number of packs in each optimizer step
+        for step_idx in range(optimizer_steps):
+            max_packs = max_packs_per_step[step_idx]
+            self._pad_to_max_packs_across_workers(packed_data_batches, step_idx, max_packs, pack_max_length)

        handles = []
        for worker_idx, worker in enumerate(self.workers):