|
41 | 41 | from tqdm import tqdm |
42 | 42 | from transfer_queue import ( |
43 | 43 | BatchMeta, |
44 | | - TransferQueueController, |
45 | 44 | SimpleStorageUnit, |
| 45 | + TransferQueueController, |
46 | 46 | get_placement_group, |
47 | 47 | process_zmq_server_info, |
48 | 48 | ) |
@@ -416,12 +416,14 @@ def __init__( |
416 | 416 | def _initialize_data_system(self): |
417 | 417 | # 1. initialize TransferQueueStorage |
418 | 418 | train_data_size = ( |
419 | | - self.config.data.train_batch_size * self.config.trainer.num_global_batch * |
420 | | - self.config.actor_rollout_ref.rollout.n |
| 419 | + self.config.data.train_batch_size |
| 420 | + * self.config.trainer.num_global_batch |
| 421 | + * self.config.actor_rollout_ref.rollout.n |
421 | 422 | ) |
422 | 423 | val_data_size = ( |
423 | | - self.val_batch_size * self.config.trainer.num_global_batch * |
424 | | - self.config.actor_rollout_ref.rollout.val_kwargs.n |
| 424 | + self.val_batch_size |
| 425 | + * self.config.trainer.num_global_batch |
| 426 | + * self.config.actor_rollout_ref.rollout.val_kwargs.n |
425 | 427 | ) |
426 | 428 |
|
427 | 429 | total_storage_size = train_data_size + val_data_size |
@@ -667,9 +669,7 @@ def _validate(self): |
667 | 669 | if self.config.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model": |
668 | 670 | return {} |
669 | 671 |
|
670 | | - asyncio.run( |
671 | | - self.data_system_client.async_put(data=test_batch, partition_id=f"val_{self.global_steps - 1}") |
672 | | - ) |
| 672 | + asyncio.run(self.data_system_client.async_put(data=test_batch, partition_id=f"val_{self.global_steps - 1}")) |
673 | 673 |
|
674 | 674 | # Store original inputs |
675 | 675 | batch_meta = asyncio.run( |
@@ -958,9 +958,7 @@ def init_workers(self): |
958 | 958 | config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg |
959 | 959 | ) |
960 | 960 |
|
961 | | - self.async_rollout_manager.create_transferqueue_client( |
962 | | - self.data_system_controller_info, self.config |
963 | | - ) |
| 961 | + self.async_rollout_manager.create_transferqueue_client(self.data_system_controller_info, self.config) |
964 | 962 |
|
965 | 963 | def _save_checkpoint(self): |
966 | 964 | from verl.utils.fs import local_mkdir_safe |
|
0 commit comments