Skip to content

Commit d81d3d1

Browse files
author
liuximeng
committed
fix codecheck
1 parent d6d12a0 commit d81d3d1

File tree

2 files changed

+10
-15
lines changed

2 files changed

+10
-15
lines changed

recipe/transfer_queue/agent_loop.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,5 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data
6969

7070
def create_transferqueue_client(self, controller_info, config):
7171
ray.get(
72-
[
73-
worker.create_transferqueue_client.remote(controller_info, config)
74-
for worker in self.agent_loop_workers
75-
]
72+
[worker.create_transferqueue_client.remote(controller_info, config) for worker in self.agent_loop_workers]
7673
)

recipe/transfer_queue/ray_trainer.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
from tqdm import tqdm
4242
from transfer_queue import (
4343
BatchMeta,
44-
TransferQueueController,
4544
SimpleStorageUnit,
45+
TransferQueueController,
4646
get_placement_group,
4747
process_zmq_server_info,
4848
)
@@ -416,12 +416,14 @@ def __init__(
416416
def _initialize_data_system(self):
417417
# 1. initialize TransferQueueStorage
418418
train_data_size = (
419-
self.config.data.train_batch_size * self.config.trainer.num_global_batch *
420-
self.config.actor_rollout_ref.rollout.n
419+
self.config.data.train_batch_size
420+
* self.config.trainer.num_global_batch
421+
* self.config.actor_rollout_ref.rollout.n
421422
)
422423
val_data_size = (
423-
self.val_batch_size * self.config.trainer.num_global_batch *
424-
self.config.actor_rollout_ref.rollout.val_kwargs.n
424+
self.val_batch_size
425+
* self.config.trainer.num_global_batch
426+
* self.config.actor_rollout_ref.rollout.val_kwargs.n
425427
)
426428

427429
total_storage_size = train_data_size + val_data_size
@@ -667,9 +669,7 @@ def _validate(self):
667669
if self.config.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model":
668670
return {}
669671

670-
asyncio.run(
671-
self.data_system_client.async_put(data=test_batch, partition_id=f"val_{self.global_steps - 1}")
672-
)
672+
asyncio.run(self.data_system_client.async_put(data=test_batch, partition_id=f"val_{self.global_steps - 1}"))
673673

674674
# Store original inputs
675675
batch_meta = asyncio.run(
@@ -958,9 +958,7 @@ def init_workers(self):
958958
config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg
959959
)
960960

961-
self.async_rollout_manager.create_transferqueue_client(
962-
self.data_system_controller_info, self.config
963-
)
961+
self.async_rollout_manager.create_transferqueue_client(self.data_system_controller_info, self.config)
964962

965963
def _save_checkpoint(self):
966964
from verl.utils.fs import local_mkdir_safe

0 commit comments

Comments
 (0)