@@ -748,48 +748,15 @@ def _validate(self):
748748 ground_truths = [item .get ("ground_truth" , None ) for item in data .get ("reward_model" , {})]
749749 sample_gts .extend (ground_truths )
750750
751- if not self .async_rollout_mode :
752- test_gen_meta = asyncio .run (
753- self .val_data_system_client .async_get_meta (
754- data_fields = [
755- "input_ids" ,
756- "attention_mask" ,
757- "position_ids" ,
758- "index" ,
759- "tools_kwargs" ,
760- "interaction_kwargs" ,
761- "ability" ,
762- "raw_prompt_ids" ,
763- ],
764- batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
765- global_step = self .global_steps - 1 , # self.global_steps start from 1
766- get_n_samples = False ,
767- task_name = "generate_sequences" ,
768- )
769- )
770- else :
771- test_gen_meta = asyncio .run (
772- self .val_data_system_client .async_get_meta (
773- data_fields = [
774- "input_ids" ,
775- "attention_mask" ,
776- "position_ids" ,
777- "index" ,
778- "tools_kwargs" ,
779- "interaction_kwargs" ,
780- "ability" ,
781- "raw_prompt_ids" ,
782- "raw_prompt" ,
783- "reward_model" ,
784- "data_source" ,
785- ],
786- batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
787- global_step = self .global_steps - 1 , # self.global_steps start from 1
788- get_n_samples = False ,
789- task_name = "async_generate_sequences" ,
790- )
751+ test_gen_meta = asyncio .run (
752+ self .val_data_system_client .async_get_meta (
753+ data_fields = list (test_batch .keys ()), # TODO: (TQ) Get metadata by specified fields
754+ batch_size = self .val_batch_size * self .config .actor_rollout_ref .rollout .val_kwargs .n ,
755+ global_step = self .global_steps - 1 , # self.global_steps start from 1
756+ get_n_samples = False ,
757+ task_name = "generate_sequences" ,
791758 )
792-
759+ )
793760 test_gen_meta .extra_info = {
794761 "eos_token_id" : self .tokenizer .eos_token_id ,
795762 "pad_token_id" : self .tokenizer .pad_token_id ,
@@ -1367,43 +1334,14 @@ def fit(self):
13671334 )
13681335 batch : TensorDict = self .dict_to_tensordict (repeated_batch_dict )
13691336 asyncio .run (self .data_system_client .async_put (data = batch , global_step = self .global_steps - 1 ))
1370- if not self .async_rollout_mode :
1371- gen_meta = asyncio .run (
1372- self .data_system_client .async_get_meta (
1373- data_fields = [
1374- "input_ids" ,
1375- "attention_mask" ,
1376- "position_ids" ,
1377- "index" ,
1378- "tools_kwargs" ,
1379- "interaction_kwargs" ,
1380- "ability" ,
1381- "raw_prompt_ids" ,
1382- ],
1383- task_name = "generate_sequences" ,
1384- ** base_get_meta_kwargs ,
1385- )
1386- )
1387- else :
1388- gen_meta = asyncio .run (
1389- self .data_system_client .async_get_meta (
1390- data_fields = [
1391- "input_ids" ,
1392- "attention_mask" ,
1393- "position_ids" ,
1394- "index" ,
1395- "tools_kwargs" ,
1396- "interaction_kwargs" ,
1397- "ability" ,
1398- "raw_prompt_ids" ,
1399- "raw_prompt" ,
1400- "reward_model" ,
1401- "data_source" ,
1402- ],
1403- task_name = "async_generate_sequences" ,
1404- ** base_get_meta_kwargs ,
1405- )
1337+
1338+ gen_meta = asyncio .run (
1339+ self .data_system_client .async_get_meta (
1340+ data_fields = list (batch .keys ()), # TODO: (TQ) Get metadata by specified fields
1341+ task_name = "generate_sequences" ,
1342+ ** base_get_meta_kwargs ,
14061343 )
1344+ )
14071345 # pass global_steps to trace
14081346 gen_meta .set_extra_info ("global_steps" , self .global_steps )
14091347
0 commit comments