import math
import random
-from typing import Literal, TypedDict, cast
+from pathlib import Path
+from typing import Literal, cast

import numpy as np
import ray
from xtuner.v1.train.trainer import LoadCheckpointConfig
from xtuner.v1.utils import get_logger, ray_method

-from .worker import TrainingWorker
-
-
-class ColateItem(TypedDict):
-    seq_ctx: SequenceContext
-    shifted_labels: torch.Tensor
-    advantage: float
-    rollout_logprobs: torch.Tensor | None
+from .worker import TrainingWorker, WorkerInputItem


class RawTrainingController:
@@ -32,6 +26,17 @@ def __init__(self, workers: list[TrainingWorker]) -> None:
            self.workers[0].get_data_replicate_size.remote(),
        ]
        self.model_cfg, self.worker_cfg, self.data_replicate_size = ray.get(refs)
+        log_dir = self.worker_cfg.log_dir
+        self.log_dir = None
+        if log_dir is not None:
+            self.log_dir = Path(log_dir) if isinstance(log_dir, str) else log_dir
+            self.logger = get_logger(log_dir=self.log_dir, tag="TrainingController")
+        else:
+            self.logger = get_logger()
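+        # Batch-derived properties; filled in later from the incoming data by _set_data_batches_properties().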
+        self.is_qwen3_vl = False
+        self.has_rollout_routed_experts = False
+        self.has_rollout_logprobs = False
+        self.n_routed_experts = None

    # TODO(hha): This logic is not general enough; it should reuse the sft function so that expand soft pack is supported
    def _get_pack_infos(self, dataset, num_tokens, target, random=None):
@@ -96,7 +101,7 @@ def _packing(self, data_batches, pack_max_length, language_cfg):
            pad_len = pack_max_length - total_len
            seq_ctx_list = [data_batches[i]["seq_ctx"] for i in indices]
            label_list = [data_batches[i]["shifted_labels"] for i in indices]
-            advantage_list = [data_batches[i]["advantage"] for i in indices]
+            advantage_list = [data_batches[i]["advantages"] for i in indices]

            rollout_logprobs_list = None
            if "rollout_logprobs" in data_batches[0] and data_batches[0]["rollout_logprobs"] is not None:
@@ -173,10 +178,10 @@ def _grouped_by_max_length(self, packed_data_batches):
        # After sorting, this pack ends up first, so rank0's first step usually consumes fewer valid tokens than the other ranks; this is expected.
        return sorted(packed_data_batches, key=lambda x: x["seq_ctx"].max_length_q, reverse=True)

-    def _balance_split_batch(self, data_batches, partition_size):
+    def _balance_split_batch(self, data_batches: list[WorkerInputItem], partition_size) -> list[list[WorkerInputItem]]:
        """Reorder the data on single controller such that each dp rank gets
        similar total tokens."""
-        global_seqlen_lst = [data["seq_ctx"].input_ids.numel() for data in data_batches]
+        global_seqlen_lst = [data["seq_ctx"].input_ids.numel() for data in data_batches]  # type: ignore[union-attr]
        global_partition_lst = get_seqlen_balanced_partitions(
            global_seqlen_lst, k_partitions=partition_size, equal_size=True
        )
@@ -189,16 +194,12 @@ def _balance_split_batch(self, data_batches, partition_size):
        get_logger().info(f"Balanced split into {partition_size} partitions with tokens: {tokens_in_partition}")
        return balanced_batches

-    def _create_padding_sample(
+    def _create_padding_item(
        self,
        pad_len: int,
        pack_max_length: int,
-        is_qwen3_vl: bool = False,
-        has_rollout_routed_experts: bool = False,
-        has_rollout_logprobs: bool = True,
-        n_routed_experts: int | None = None,
        split_size: int = 1024,
-    ):
+    ) -> WorkerInputItem:
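+        # Build a dummy WorkerInputItem of length pad_len used purely as padding; its labels are -100 (the ignore index), so padded tokens do not contribute to the loss.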
        # padding input_ids
        pad_tokens = tuple(
            torch.zeros(1, split_size, dtype=torch.long, device="cpu") for _ in range(pad_len // split_size)
@@ -210,7 +211,7 @@ def _create_padding_sample(
        pad_seq_ctx.num_padding = pad_len

        # padding mm positions_ids
-        if is_qwen3_vl:
+        if self.is_qwen3_vl:
            _position_ids_list = []
            for pad_token in pad_tokens:
                _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1)
@@ -220,17 +221,17 @@ def _create_padding_sample(
        pad_seq_ctx.position_ids = position_ids

        # padding rollout routed experts
-        if has_rollout_routed_experts:
-            assert n_routed_experts, "n_routed_experts must be provided when has_rollout_routed_experts is True"
+        if self.has_rollout_routed_experts:
+            assert self.n_routed_experts, "n_routed_experts must be provided when has_rollout_routed_experts is True"
            if pad_len == pack_max_length:
                pad_rand_index = torch.randint(
                    low=0, high=1, size=(1, 1, 1)
                )  # add dummy data, true data will be initialized in train worker.fit
            else:
-                pad_rand_index = torch.randint(low=0, high=n_routed_experts, size=(pad_len, 1, 1))
+                pad_rand_index = torch.randint(low=0, high=self.n_routed_experts, size=(pad_len, 1, 1))
            pad_seq_ctx.rollout_routed_experts = pad_rand_index

-        pad_labels = torch.full((1, pad_len), -100, dtype=torch.long, device="cpu")
+        pad_labels = cast(torch.LongTensor, torch.full((1, pad_len), -100, dtype=torch.int64, device="cpu"))
        pad_advantage_length = pack_max_length if pad_len == pack_max_length else math.ceil(pad_len / 1024)
        pad_advantage = torch.full(
            (1, pad_advantage_length),
@@ -239,24 +240,27 @@ def _create_padding_sample(
            device="cpu",
        )
        pad_rollout_logprobs = (
-            torch.zeros(1, pad_len, dtype=torch.float32, device="cpu") if has_rollout_logprobs else None
+            torch.zeros(1, pad_len, dtype=torch.float32, device="cpu") if self.has_rollout_logprobs else None
        )

-        return {
+        padding_item: WorkerInputItem = {
            "seq_ctx": pad_seq_ctx,
            "shifted_labels": pad_labels,
            "advantages": pad_advantage,
            "rollout_logprobs": pad_rollout_logprobs,
        }
+        return padding_item

-    def _pack(self, mini_batch, pack_max_length):
+    def _rearrange_batch_for_pack(
+        self, mini_batch: list[WorkerInputItem], pack_max_length: int
+    ) -> list[list[WorkerInputItem]]:
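+        # Group samples of the mini batch so that each group's total token count fits within pack_max_length.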
        assert len(mini_batch) > 0, "mini_batch should not be empty"
        seqlen_list = []
        for data in mini_batch:
-            assert data["seq_ctx"].input_ids.numel() <= pack_max_length, (
-                f"Single sample seq len {data['seq_ctx'].input_ids.numel()} exceeds pack_max_length {pack_max_length}"
+            assert data["seq_ctx"].input_ids.numel() <= pack_max_length, (  # type: ignore[union-attr]
+                f"Single sample seq len {data['seq_ctx'].input_ids.numel()} exceeds pack_max_length {pack_max_length}"  # type: ignore[union-attr]
            )
-            seqlen_list.append(data["seq_ctx"].input_ids.numel())
+            seqlen_list.append(data["seq_ctx"].input_ids.numel())  # type: ignore[union-attr]
        total_length = sum(seqlen_list)

        if total_length <= pack_max_length:
@@ -273,15 +277,10 @@ def _pack(self, mini_batch, pack_max_length):
            packed_mini_batches.append(packed_batch)
        return packed_mini_batches

-    def _get_data_batches_properties(self, data_batches: list[ColateItem]):
+    def _set_data_batches_properties(self, data_batches: list[WorkerInputItem]):
        """Extract properties from the first element of data_batches."""
        if not data_batches:
-            return {
-                "is_qwen3_vl": False,
-                "has_rollout_routed_experts": False,
-                "has_rollout_logprobs": False,
-                "n_routed_experts": None,
-            }
+            return

        first_item = data_batches[0]
        seq_ctx = first_item["seq_ctx"]
@@ -296,114 +295,128 @@ def _get_data_batches_properties(self, data_batches: list[ColateItem]):
        if isinstance(self.model_cfg, BaseComposeConfig):
            language_cfg = self.model_cfg.text_config

-        return {
-            "is_qwen3_vl": is_qwen3_vl,
-            "has_rollout_routed_experts": has_rollout_routed_experts,
-            "has_rollout_logprobs": has_rollout_logprobs,
-            "n_routed_experts": language_cfg.n_routed_experts if language_cfg is not None else None,
+        self.is_qwen3_vl = is_qwen3_vl
+        self.has_rollout_routed_experts = has_rollout_routed_experts
+        self.has_rollout_logprobs = has_rollout_logprobs
+        self.n_routed_experts = language_cfg.n_routed_experts if language_cfg is not None else None
+
+    def _pad_and_pack_batches(self, batch4pack: list[WorkerInputItem], pack_max_length: int) -> WorkerInputItem:
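+        # Pack one group of samples into a single fixed-length training pack, padding the tail up to pack_max_length.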
+        seq_ctx_list = [item["seq_ctx"] for item in batch4pack]
+        label_list = [item["shifted_labels"] for item in batch4pack]
+        advantage_list = [torch.tensor([item["advantages"]]).float().unsqueeze(0) for item in batch4pack]
+        rollout_logprobs_list = [
+            item["rollout_logprobs"] if self.has_rollout_logprobs else None for item in batch4pack
+        ]
+        cur_length = 0
+        for item in batch4pack:
+            cur_length += item["seq_ctx"].input_ids.numel()  # type: ignore[union-attr]
+        padding_len = pack_max_length - cur_length
+
+        if padding_len > 0:
+            padding_item = self._create_padding_item(padding_len, pack_max_length)
+            seq_ctx_list.append(padding_item["seq_ctx"])
+            label_list.append(padding_item["shifted_labels"])
+            advantage_list.append(padding_item["advantages"])
+            rollout_logprobs_list.append(padding_item["rollout_logprobs"])
+
+        packed_seq_ctx = SequenceContext.pack(seq_ctx_list)
+        packed_shifted_labels = torch.cat(label_list, dim=1)  # type: ignore[arg-type]
+        packed_shifted_labels = cast(torch.LongTensor, packed_shifted_labels)
+        cu_seq_lens_q = packed_seq_ctx.cu_seq_lens_q
+        packed_num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
+        packed_advantages = torch.cat(advantage_list, dim=1)
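+        # broadcast each sample's scalar advantage to every token of that sample in the pack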
+        packed_advantages = torch.repeat_interleave(packed_advantages, packed_num_tokens, dim=1)
+        if self.has_rollout_logprobs:
+            cast_rollout_logprobs_list = [cast(torch.Tensor, item) for item in rollout_logprobs_list]
+            packed_rollout_logprobs = torch.cat(cast_rollout_logprobs_list, dim=1)
+        else:
+            packed_rollout_logprobs = None
+
+        optimizer_step_packs: WorkerInputItem = {
+            "seq_ctx": packed_seq_ctx,
+            "shifted_labels": packed_shifted_labels,
+            "advantages": packed_advantages,
+            "rollout_logprobs": packed_rollout_logprobs,
        }
+        return optimizer_step_packs
+
+    def _pad_to_max_packs_across_workers(
+        self,
+        packed_data_batches: list[list[list[WorkerInputItem]]],
+        step_idx: int,
+        max_packs: int,
+        pack_max_length: int,
+    ):
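+        # Pad this optimizer step's pack list on every dp rank with full-length dummy packs so all ranks run the same number of gradient-accumulation micro-batches.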
+        for dp_rank in range(len(packed_data_batches)):
+            num_current_packs = len(packed_data_batches[dp_rank][step_idx])
+            num_padding_packs = max_packs - num_current_packs
+
+            if num_padding_packs > 0:
+                padding_item = self._create_padding_item(pack_max_length, pack_max_length)
+                padding_items = [padding_item for _ in range(num_padding_packs)]
+                packed_data_batches[dp_rank][step_idx].extend(padding_items)

    @ray_method
    def fit(
-        self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int, enable_dp_balance: bool = True
+        self,
+        data_batches: list[WorkerInputItem],
+        pack_max_length: int,
+        rollout_idx: int,
+        enable_dp_balance: bool = True,
    ):
-        batch_props = self._get_data_batches_properties(data_batches)
-        is_qwen3_vl = batch_props["is_qwen3_vl"]
-        has_rollout_routed_experts = batch_props["has_rollout_routed_experts"]
-        has_rollout_logprobs = batch_props["has_rollout_logprobs"]
-        n_routed_experts = batch_props["n_routed_experts"]
+        self._set_data_batches_properties(data_batches)

        world_size = len(self.workers)
        dp_size = world_size // self.data_replicate_size
        assert world_size % self.data_replicate_size == 0, "world_size must be divisible by data_replicate_size"
        optimizer_steps = self.worker_cfg.optimizer_steps

+        batches_per_dp_group: list[list[WorkerInputItem]]
        if enable_dp_balance:
            # Repartition the data by dp_size so that each dp rank gets roughly the same number of tokens
            batches_per_dp_group = self._balance_split_batch(data_batches, dp_size)
        else:
            batches_per_dp_group = np.array_split(data_batches, dp_size)
            tokens_in_partition = []
            for batch in batches_per_dp_group:
-                tokens_in_partition.append(sum(data["seq_ctx"].input_ids.numel() for data in batch))
-            get_logger().info(f"default split into {dp_size} partitions with tokens: {tokens_in_partition}")
-
-        packed_data_batches: list[list[list[dict]]] = [[[] for _ in range(optimizer_steps)] for _ in range(dp_size)]
-        max_packs_per_card = [0] * optimizer_steps
+                dp_group_total_tokens = 0
+                for data in batch:
+                    dp_group_total_tokens += data["seq_ctx"].input_ids.numel()  # type: ignore[union-attr]
+                tokens_in_partition.append(dp_group_total_tokens)
+            self.logger.info(f"default split into {dp_size} partitions with tokens: {tokens_in_partition}")
+
+        packed_data_batches: list[list[list[WorkerInputItem]]] = [
+            [[] for _ in range(optimizer_steps)] for _ in range(dp_size)
+        ]
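+        # the largest number of packs any dp rank produces for a given optimizer step, i.e. its gradient-accumulation count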
+        max_packs_per_step = [0] * optimizer_steps

        for dp_rank, dp_worker_data_batches in enumerate(batches_per_dp_group):
            # Within each worker, split the tokens evenly across optimizer_steps
            if enable_dp_balance:
                random.shuffle(dp_worker_data_batches)
-            mini_batch_for_steps = self._balance_split_batch(dp_worker_data_batches, optimizer_steps)
+            mini_batch_for_steps: list[list[WorkerInputItem]] = self._balance_split_batch(
+                dp_worker_data_batches, optimizer_steps
+            )

            for step_idx, step_mini_batch in enumerate(mini_batch_for_steps):
-                # pack
-                pack_mini_batch = self._pack(step_mini_batch, pack_max_length)
-                if len(pack_mini_batch) > max_packs_per_card[step_idx]:
-                    max_packs_per_card[step_idx] = len(pack_mini_batch)
-
-                for pack in pack_mini_batch:
-                    seq_ctx_list = [item["seq_ctx"] for item in pack]
-                    label_list = [item["shifted_labels"] for item in pack]
-                    advantage_list = [torch.tensor([item["advantage"]]).float().unsqueeze(0) for item in pack]
-                    rollout_logprobs_list = [
-                        item["rollout_logprobs"] if has_rollout_logprobs else None for item in pack
-                    ]
-                    padding_len = pack_max_length - sum([item["seq_ctx"].input_ids.numel() for item in pack])
-                    if padding_len > 0:
-                        padding_sample = self._create_padding_sample(
-                            padding_len,
-                            pack_max_length,
-                            is_qwen3_vl=is_qwen3_vl,
-                            has_rollout_routed_experts=has_rollout_routed_experts,
-                            has_rollout_logprobs=has_rollout_logprobs,
-                            n_routed_experts=n_routed_experts,
-                        )
-                        seq_ctx_list.append(padding_sample["seq_ctx"])
-                        label_list.append(padding_sample["shifted_labels"])
-                        advantage_list.append(padding_sample["advantages"])
-                        rollout_logprobs_list.append(padding_sample["rollout_logprobs"])
-
-                    packed_seq_ctx = SequenceContext.pack(seq_ctx_list)
-                    packed_shifted_labels = torch.cat(label_list, dim=1)
-                    cu_seq_lens_q = packed_seq_ctx.cu_seq_lens_q
-                    packed_num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
-                    packed_advantages = torch.cat(advantage_list, dim=1)
-                    packed_advantages = torch.repeat_interleave(packed_advantages, packed_num_tokens, dim=1)
-                    if has_rollout_logprobs:
-                        cast_rollout_logprobs_list = [cast(torch.Tensor, item) for item in rollout_logprobs_list]
-                        packed_rollout_logprobs = torch.cat(cast_rollout_logprobs_list, dim=1)
-                    else:
-                        packed_rollout_logprobs = None
-                    packed_data_batches[dp_rank][step_idx].append(
-                        {
-                            "seq_ctx": packed_seq_ctx,
-                            "shifted_labels": packed_shifted_labels,
-                            "advantages": packed_advantages,
-                            "rollout_logprobs": packed_rollout_logprobs,
-                        }
-                    )
+                # rearrange mini batch to fit into packs of pack_max_length
+                batch4pack_list: list[list[WorkerInputItem]] = self._rearrange_batch_for_pack(
+                    step_mini_batch, pack_max_length
+                )
+                if len(batch4pack_list) > max_packs_per_step[step_idx]:
+                    max_packs_per_step[step_idx] = len(batch4pack_list)

-        get_logger().info(f"Gradient accumulation steps: {max_packs_per_card}")
-        # padding for each worker to have same number of packs
-        for dp_rank in range(dp_size):
-            for step_idx in range(optimizer_steps):
-                max_packs = max_packs_per_card[step_idx]
-                num_current_packs = len(packed_data_batches[dp_rank][step_idx])
-                num_padding_packs = max_packs - num_current_packs
-
-                if num_padding_packs > 0:
-                    padding_sample = self._create_padding_sample(
-                        pack_max_length,
-                        pack_max_length,
-                        is_qwen3_vl=is_qwen3_vl,
-                        has_rollout_routed_experts=has_rollout_routed_experts,
-                        has_rollout_logprobs=has_rollout_logprobs,
-                        n_routed_experts=n_routed_experts,
-                    )
-                    padding_samples = [padding_sample for _ in range(num_padding_packs)]
-                    packed_data_batches[dp_rank][step_idx].extend(padding_samples)
+                for batch4pack in batch4pack_list:
+                    # pad and pack batches into a single optimizer step pack
+                    step_pack = self._pad_and_pack_batches(batch4pack, pack_max_length)
+                    packed_data_batches[dp_rank][step_idx].append(step_pack)
+
+        self.logger.info(f"Gradient accumulation steps per optimizer step: {max_packs_per_step}")
+
+        # padding so that every worker has the same number of packs in each optimizer step
+        for step_idx in range(optimizer_steps):
+            max_packs = max_packs_per_step[step_idx]
+            self._pad_to_max_packs_across_workers(packed_data_batches, step_idx, max_packs, pack_max_length)

        handles = []
        for worker_idx, worker in enumerate(self.workers):