1+ import unittest
2+ import torch
3+ from xtuner .v1 .data_proto .sequence_context import SequenceContext
4+ from xtuner .v1 .rl .base .pack import DataBatchPacker
5+
class TestDataBatchPacker(unittest.TestCase):
    """Tests for ``DataBatchPacker`` covering the native/balance/greedy strategies."""

    def setUp(self):
        # Capacity of one packed sequence; split_size mirrors the packer's
        # internal chunk granularity (kept for parity with the original test).
        self.pack_max_length = 3072
        self.split_size = 1024

    def _create_dummy_item(self, length: int, val: int = 1) -> dict:
        """Create a single data item of ``length`` tokens, all filled with ``val``.

        The returned dict mimics the structure ``DataBatchPacker.pack`` expects:
        a ``SequenceContext`` plus per-token label/advantage/logprob tensors,
        each shaped ``(1, length)``.
        """
        input_ids = torch.full((1, length), val, dtype=torch.long)
        # Single-sequence cumulative boundaries: [0, length] for both q and k.
        cu_seq_lens_q = torch.tensor([0, length], dtype=torch.int32)
        cu_seq_lens_k = torch.tensor([0, length], dtype=torch.int32)
        max_length_q = torch.tensor(length, dtype=torch.int32)
        max_length_k = torch.tensor(length, dtype=torch.int32)
        seq_ctx = SequenceContext(
            input_ids=input_ids,
            cu_seq_lens_q=cu_seq_lens_q,
            cu_seq_lens_k=cu_seq_lens_k,
            max_length_q=max_length_q,
            max_length_k=max_length_k,
            num_padding=0,
            device="cpu",
        )
        return {
            "seq_ctx": seq_ctx,
            "shifted_labels": torch.full((1, length), val, dtype=torch.long),
            "advantages": torch.full((1, length), float(val), dtype=torch.float),
            "rollout_logprobs": torch.full((1, length), float(val), dtype=torch.float),
        }

    def _run_strategy_test(self, strategy, world_size, optimizer_steps, lengths, expected_padding):
        """Pack ``lengths`` with ``strategy`` and verify padding and token conservation.

        Checks that:
          * the reported padding token count equals ``expected_padding``;
          * total pack capacity equals data tokens + padding tokens;
          * no valid token/label/advantage is lost or duplicated by packing
            (padding is assumed to use 0 for input_ids and -100 for
            labels/advantages — matching the assertions below).
        """
        data_batches = [self._create_dummy_item(length, val=7) for length in lengths]
        total_data_tokens = sum(lengths)

        packer = DataBatchPacker(
            pack_max_length=self.pack_max_length,
            world_size=world_size,
            data_replicate_size=1,
            optimizer_steps=optimizer_steps,
            pack_strategy=strategy,
        )

        packed_res, padding_tokens = packer.pack(data_batches)

        self.assertEqual(
            padding_tokens,
            expected_padding,
            f"Strategy {strategy} padding mismatch. Expected {expected_padding}, got {padding_tokens}",
        )

        # Flatten rank -> step -> packs into a single list of packs.
        all_packs = []
        for rank_data in packed_res:
            for step_data in rank_data:
                all_packs.extend(step_data)

        # Every pack is padded up to exactly pack_max_length tokens.
        total_capacity = len(all_packs) * self.pack_max_length
        self.assertEqual(total_capacity, total_data_tokens + padding_tokens)

        # Valid (non-padding) entries must account for every input token:
        # dummy items are filled with 7, so 0 / -100 can only come from padding.
        valid_token_count = sum((p["seq_ctx"].input_ids != 0).sum().item() for p in all_packs)
        valid_label_count = sum((p["shifted_labels"] != -100).sum().item() for p in all_packs)
        valid_adv_count = sum((p["advantages"] != -100).sum().item() for p in all_packs)

        self.assertEqual(valid_token_count, total_data_tokens)
        self.assertEqual(valid_label_count, total_data_tokens)
        self.assertEqual(valid_adv_count, total_data_tokens)

    def test_variable_packs(self):
        """Random token counts, dp=2, optimizer_steps=2.

        - Native:
          1. DP-rank split:
             rank0: [1500, 1000, 2800, 1500]
             rank1: [2000, 2100, 1000, 800]
          2. Optimizer-step split:
             rank0: [1500, 1000], [2800, 1500]
             rank1: [2000, 2100], [1000, 800]
          3. Pack and pad:
             rank0: step0: [2500 -> 3072], step1: [2800 -> 3072], [1500 -> 3072]
             rank1: step0: [2100 -> 3072], [2000 -> 3072], step1: [1800 -> 3072]
          4. Cross-rank pack-count alignment:
             rank0: step0: [2500 -> 3072], [0 -> 3072] step1: [2800 -> 3072], [1500 -> 3072]
             rank1: step0: [2100 -> 3072], [2000 -> 3072], step1: [1800 -> 3072], [0 -> 3072]
          padding_tokens: 3072 - 2500 + 3072 + 3072 - 2800 + 3072 - 1500 + 3072 - 2100 + 3072 - 2000 + 3072 - 1800 + 3072 = 11876
        - Balance:
          1. Balanced DP-rank split:
             rank0: [2800, 1500, 1000, 1000]
             rank1: [2000, 2100, 1500, 800]
          2. Balanced optimizer-step split:
             rank0: [2800], [1500, 1000, 1000]
             rank1: [2100, 800], [2000, 1500]
          3. Pack and pad:
             rank0: step0: [2800 -> 3072], step1: [2500 -> 3072], [1000 -> 3072]
             rank1: step0: [2900 -> 3072], step1: [2000 -> 3072], [1500 -> 3072]
          4. Cross-rank pack-count alignment:
             skip
          padding_tokens: 3072 - 2800 + 3072 - 2500 + 3072 - 1000 + 3072 - 2900 + 3072 - 2000 + 3072 - 1500 = 5732
        - Greedy (greedy-first packing): maximizes pack fill rate; does not
          preserve sample order across steps.
          1. Pack and pad:
             Pack 1: [1500, 1000] -> [2500 -> 3072]
             Pack 2: [2800] -> [2800 -> 3072]
             Pack 3: [1500] -> [1500 -> 3072]
             Pack 4: [2000] -> [2000 -> 3072]
             Pack 5: [2100] -> [2100 -> 3072]
             Pack 6: [1000, 800] -> [1800 -> 3072]
          2. DP split:
             rank0: [Pack 1, Pack 2, Pack 3]
             rank1: [Pack 4, Pack 5, Pack 6]
          3. Optimizer-step split:
             rank0: step0: [Pack 1, Pack 2], step1: [Pack 3]
             rank1: step0: [Pack 4, Pack 5], step1: [Pack 6]
          4. Cross-rank pack-count alignment:
             skip
          padding_tokens: 3072 - 2500 + 3072 - 2800 + 3072 - 1500 + 3072 - 2000 + 3072 - 2100 + 3072 - 1800 = 5732
        """
        lengths = [1500, 1000, 2800, 1500, 2000, 2100, 1000, 800]
        self._run_strategy_test("native", 2, 2, lengths, 11876)
        self._run_strategy_test("balance", 2, 2, lengths, 5732)
        self._run_strategy_test("greedy", 2, 2, lengths, 5732)

    def test_imbalance_dp_size(self):
        # A single short sample with dp=2: one rank gets data, the other is
        # filled entirely with padding. Expected padding is identical for all
        # strategies: (3072 - 500) + 3072 = 5644.
        lengths = [500]
        for strat in ["native", "balance", "greedy"]:
            self._run_strategy_test(strat, 2, 1, lengths, 5644)

    def test_imbalanced_steps(self):
        # optimizer_steps=4 with highly uneven sample lengths: order-preserving
        # strategies (native/balance) pay heavy alignment padding, while greedy
        # repacks freely and pads far less.
        lengths = [100, 200, 2500, 3000, 50, 400, 1000, 1500]
        self._run_strategy_test("native", 2, 4, lengths, 15826)
        self._run_strategy_test("balance", 2, 4, lengths, 15826)
        self._run_strategy_test("greedy", 2, 4, lengths, 3538)
128+
129+
# Allow running this test module directly: `python test_pack.py`.
if __name__ == "__main__":
    unittest.main()