@@ -55,7 +55,6 @@ def setUp(self):
5555 rewards = [item ['reward' ] for item in group ]
5656 rewards = torch .tensor (rewards , dtype = torch .float32 )
5757 advantages = (rewards - rewards .mean (0 )) / (rewards .std (0 ) + 1e-8 )
58-
5958 for i in range (self .prompt_repeat_k ):
6059 item = group [i ]
6160 response_ids = tokenizer (item ['response' ], return_tensors = 'pt' )['input_ids' ].flatten ().tolist ()
@@ -67,7 +66,7 @@ def setUp(self):
6766 dict (
6867 seq_ctx = SequenceContext .from_input_ids ((input_ids , ), device = "cpu" ),
6968 shifted_labels = shifted_labels ,
70- advantage = advantages [i ]. item () ,
69+ advantages = advantages [i ],
7170 )
7271 )
7372 self .data_batches = data_batches
@@ -126,8 +125,125 @@ def build_train_controller(self):
126125 ray .get (train_controller .__ray_ready__ .remote ())
127126 return train_controller
128127
129- def test_grpo_train_and_save (self ):
128+ # def test_grpo_train_and_save(self):
129+ # train_controller = self.build_train_controller()
130+ # ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=8192, rollout_idx=0))
131+ # save_path = os.path.join(self.temp_dir, "hf_test")
132+ # ray.get(train_controller.save_hf.remote(str(save_path)))
133+
134+ def _create_dummy_item (self , length : int ):
135+ """Helper to create a dummy WorkerInputItem"""
136+ input_ids = torch .ones (1 , length , dtype = torch .long )
137+ cu_seq_lens_q = torch .tensor ([0 , length ], dtype = torch .int32 )
138+ cu_seq_lens_k = torch .tensor ([0 , length ], dtype = torch .int32 )
139+ max_length_q = torch .tensor (length , dtype = torch .int32 )
140+ max_length_k = torch .tensor (length , dtype = torch .int32 )
141+ seq_ctx = SequenceContext (
142+ input_ids = input_ids ,
143+ cu_seq_lens_q = cu_seq_lens_q ,
144+ cu_seq_lens_k = cu_seq_lens_k ,
145+ max_length_q = max_length_q ,
146+ max_length_k = max_length_k ,
147+ num_padding = 0 ,
148+ device = "cpu" ,
149+ )
150+ return {
151+ "seq_ctx" : seq_ctx ,
152+ "shifted_labels" : torch .ones (1 , length , dtype = torch .long ),
153+ "advantages" : torch .rand (1 , 1 , dtype = torch .float ),
154+ "rollout_logprobs" : torch .ones (1 , length , dtype = torch .float ),
155+ }
156+
157+ def test_controller_logic (self ):
158+ """
159+ Unit tests for RawTrainingController internal logic using the real Ray actor:
160+ - _balance_split_batch
161+ - _create_padding_item
162+ - _rearrange_batch_for_pack
163+ - _pad_and_pack_batches
164+ """
165+ # 1. Build the real train controller
130166 train_controller = self .build_train_controller ()
131- ray .get (train_controller .fit .remote (self .data_batches , pack_max_length = 1024 , rollout_idx = 0 ))
132- save_path = os .path .join (self .temp_dir , "hf_test" )
133- ray .get (train_controller .save_hf .remote (str (save_path )))
167+ pack_max_length = 100
168+
169+ # --- Test 1: _balance_split_batch ---
170+ print ("Testing _balance_split_batch..." )
171+ # Input: 4 items with lengths 10, 20, 30, 40
172+ items = [self ._create_dummy_item (l ) for l in [10 , 20 , 30 , 40 ]]
173+ dp_size = 2
174+
175+ # Call remote method
176+ # 10, 20, 30, 40 -> sum 100 -> avg 50.
177+ # Expected split: [10, 40] (sum 50) and [20, 30] (sum 50)
178+ result = ray .get (train_controller ._balance_split_batch .remote (items , dp_size ))
179+
180+ self .assertEqual (len (result ), 2 )
181+ self .assertEqual (len (result [0 ]), 2 )
182+ self .assertEqual (len (result [1 ]), 2 )
183+
184+ # Verify balance
185+ len_group0 = sum (item ["seq_ctx" ].input_ids .shape [1 ] for item in result [0 ])
186+ len_group1 = sum (item ["seq_ctx" ].input_ids .shape [1 ] for item in result [1 ])
187+ self .assertEqual (len_group0 , 50 )
188+ self .assertEqual (len_group1 , 50 )
189+
190+ # --- Test 2: _rearrange_batch_for_pack ---
191+ print ("Testing _rearrange_batch_for_pack..." )
192+ # Input: [40, 40, 30], max=100. With get_seqlen_balanced_partitions, it should be packed as [40, 30] and [40]
193+ items_pack = [self ._create_dummy_item (l ) for l in [40 , 40 , 30 ]]
194+ batches = ray .get (train_controller ._rearrange_batch_for_pack .remote (items_pack , pack_max_length ))
195+
196+ self .assertEqual (len (batches ), 2 )
197+ self .assertEqual (len (batches [0 ]), 2 ) # 40 + 30 = 70
198+ self .assertEqual (len (batches [1 ]), 1 ) # 40
199+ self .assertEqual (batches [0 ][0 ]["seq_ctx" ].input_ids .shape [1 ] + batches [0 ][1 ]["seq_ctx" ].input_ids .shape [1 ], 70 )
200+ self .assertEqual (batches [1 ][0 ]["seq_ctx" ].input_ids .shape [1 ], 40 )
201+ # --- Test 3: _pad_and_pack_batches ---
202+ print ("Testing _pad_and_pack_batches..." )
203+ # Input: First batch with length 70. Should pad 30 to reach 100. Second batch with length 40, should pad 60 to reach 100.
204+ for idx , batch4pack_list in enumerate (batches ):
205+ packed_item = ray .get (train_controller ._pad_and_pack_batches .remote (batch4pack_list , pack_max_length ))
206+ # Check total length
207+ self .assertEqual (packed_item ["seq_ctx" ].input_ids .shape [1 ], pack_max_length )
208+ # idx == 0:
209+ if idx == 0 :
210+ # Check cu_seq_lens_q: [0, 40, 70, 100]
211+ expected_cu_lens = torch .tensor ([0 , 40 , 70 , 100 ], dtype = torch .int32 )
212+ self .assertTrue (torch .equal (packed_item ["seq_ctx" ].cu_seq_lens_q , expected_cu_lens ))
213+ # Check padding labels are -100
214+ self .assertTrue (torch .all (packed_item ["shifted_labels" ][0 , 70 :] == - 100 ))
215+ if idx == 1 :
216+ # Check cu_seq_lens_q: [0, 40, 100]
217+ expected_cu_lens = torch .tensor ([0 , 40 , 100 ], dtype = torch .int32 )
218+ self .assertTrue (torch .equal (packed_item ["seq_ctx" ].cu_seq_lens_q , expected_cu_lens ))
219+ # Check padding labels are -100
220+ self .assertTrue (torch .all (packed_item ["shifted_labels" ][0 , 40 :] == - 100 ))
221+
222+ # --- Test 4: _pad_to_max_packs_across_workes ---
223+ pack_dummy = {"dummy" : "pack" }
224+ packed_data_batches = [
225+ [[pack_dummy , pack_dummy ]], # Worker 0: 2 packs
226+ [[pack_dummy ]] # Worker 1: 1 pack
227+ ]
228+ # Execute the function locally
229+ packed_data_batches = ray .get (train_controller ._pad_to_max_packs_across_workes .remote (
230+ packed_data_batches , 0 , 2 , pack_max_length
231+ ))
232+ # Verification
233+ # Worker 0 should still have 2 packs
234+ self .assertEqual (len (packed_data_batches [0 ][0 ]), 2 )
235+
236+ # Worker 1 should now have 2 packs (1 original + 1 padding)
237+ self .assertEqual (len (packed_data_batches [1 ][0 ]), 2 )
238+
239+ # Verify the added item is a padding item
240+ added_pack = packed_data_batches [1 ][0 ][1 ]
241+ # Since we used the real _create_padding_item, it should have the correct structure
242+ self .assertIn ("seq_ctx" , added_pack )
243+ self .assertIn ("shifted_labels" , added_pack )
244+ self .assertEqual (added_pack ["seq_ctx" ].input_ids .shape [1 ], pack_max_length )
245+ self .assertTrue (torch .all (added_pack ["shifted_labels" ] == - 100 ))
246+ print ("All controller logic tests passed!" )
247+
# Allow running this test file directly (outside a pytest/unittest runner).
if __name__ == "__main__":
    unittest.main()
0 commit comments