Skip to content

Commit 9013472

Browse files
LLLLxmmm and liuximeng authored
Adapt for multimodal data format (#42)
Co-authored-by: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com>
1 parent 7a7a0c1 commit 9013472

File tree

1 file changed

+15
-77
lines changed

1 file changed

+15
-77
lines changed

recipe/transfer_queue/ray_trainer.py

Lines changed: 15 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -748,48 +748,15 @@ def _validate(self):
748748
ground_truths = [item.get("ground_truth", None) for item in data.get("reward_model", {})]
749749
sample_gts.extend(ground_truths)
750750

751-
if not self.async_rollout_mode:
752-
test_gen_meta = asyncio.run(
753-
self.val_data_system_client.async_get_meta(
754-
data_fields=[
755-
"input_ids",
756-
"attention_mask",
757-
"position_ids",
758-
"index",
759-
"tools_kwargs",
760-
"interaction_kwargs",
761-
"ability",
762-
"raw_prompt_ids",
763-
],
764-
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
765-
global_step=self.global_steps - 1, # self.global_steps start from 1
766-
get_n_samples=False,
767-
task_name="generate_sequences",
768-
)
769-
)
770-
else:
771-
test_gen_meta = asyncio.run(
772-
self.val_data_system_client.async_get_meta(
773-
data_fields=[
774-
"input_ids",
775-
"attention_mask",
776-
"position_ids",
777-
"index",
778-
"tools_kwargs",
779-
"interaction_kwargs",
780-
"ability",
781-
"raw_prompt_ids",
782-
"raw_prompt",
783-
"reward_model",
784-
"data_source",
785-
],
786-
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
787-
global_step=self.global_steps - 1, # self.global_steps start from 1
788-
get_n_samples=False,
789-
task_name="async_generate_sequences",
790-
)
751+
test_gen_meta = asyncio.run(
752+
self.val_data_system_client.async_get_meta(
753+
data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields
754+
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
755+
global_step=self.global_steps - 1, # self.global_steps start from 1
756+
get_n_samples=False,
757+
task_name="generate_sequences",
791758
)
792-
759+
)
793760
test_gen_meta.extra_info = {
794761
"eos_token_id": self.tokenizer.eos_token_id,
795762
"pad_token_id": self.tokenizer.pad_token_id,
@@ -1367,43 +1334,14 @@ def fit(self):
13671334
)
13681335
batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict)
13691336
asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1))
1370-
if not self.async_rollout_mode:
1371-
gen_meta = asyncio.run(
1372-
self.data_system_client.async_get_meta(
1373-
data_fields=[
1374-
"input_ids",
1375-
"attention_mask",
1376-
"position_ids",
1377-
"index",
1378-
"tools_kwargs",
1379-
"interaction_kwargs",
1380-
"ability",
1381-
"raw_prompt_ids",
1382-
],
1383-
task_name="generate_sequences",
1384-
**base_get_meta_kwargs,
1385-
)
1386-
)
1387-
else:
1388-
gen_meta = asyncio.run(
1389-
self.data_system_client.async_get_meta(
1390-
data_fields=[
1391-
"input_ids",
1392-
"attention_mask",
1393-
"position_ids",
1394-
"index",
1395-
"tools_kwargs",
1396-
"interaction_kwargs",
1397-
"ability",
1398-
"raw_prompt_ids",
1399-
"raw_prompt",
1400-
"reward_model",
1401-
"data_source",
1402-
],
1403-
task_name="async_generate_sequences",
1404-
**base_get_meta_kwargs,
1405-
)
1337+
1338+
gen_meta = asyncio.run(
1339+
self.data_system_client.async_get_meta(
1340+
data_fields=list(batch.keys()), # TODO: (TQ) Get metadata by specified fields
1341+
task_name="generate_sequences",
1342+
**base_get_meta_kwargs,
14061343
)
1344+
)
14071345
# pass global_steps to trace
14081346
gen_meta.set_extra_info("global_steps", self.global_steps)
14091347

0 commit comments

Comments
 (0)