52 changes: 51 additions & 1 deletion tests/trainer/trainer_test.py
@@ -115,6 +115,56 @@ def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir)


class TestStepAheadAsyncRL(BaseTrainerCase):
def test_trainer(self):
"""Test the explore step ahead trainer"""
        # train 4 steps, sync_offset=1, sync_interval=2
# Explorer:
# | 1 | 2 | 3 |sync| 4 |
# |---|---|---|sync|---|
# Trainer:
# | 1 | 2 |sync| 3 | 4 |
# |---|---|sync|---|---|
self.config.buffer.total_epochs = 1
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
self.config.trainer.save_interval = 4
self.config.synchronizer.sync_interval = 2
self.config.synchronizer.sync_offset = 1
self.config.check_and_update()
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 1
self.config.trainer.trainer_config.trainer.max_critic_ckpt_to_keep = 1

both(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
actor_metrics = parser.metric_list("actor")
self.assertTrue(len(actor_metrics) > 0)
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
actor_kl_metrics = parser.metric_list("actor/kl")
self.assertTrue(len(actor_kl_metrics) > 0)
critic_kl_metrics = parser.metric_list("critic/kl")
self.assertTrue(len(critic_kl_metrics) > 0)
response_metrics = parser.metric_list("response_length")
self.assertTrue(len(response_metrics) > 0)
self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
ray.shutdown(_exiting_interpreter=True)
# check checkpoint
from trinity.common.models.utils import get_checkpoint_dir_with_step_num

checkpoint_step_4 = get_checkpoint_dir_with_step_num(
checkpoint_root_path=self.config.checkpoint_job_dir,
trainer_type=self.config.trainer.trainer_type,
step_num=4,
)
self.assertTrue(os.path.exists(checkpoint_step_4))

def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)


class TestTrainerGSM8K(BaseTrainerCase):
def test_trainer(self):
"""Test GSM8K."""
@@ -153,7 +203,7 @@ def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir)


class TestTrainerGSM8KWithSFT(BaseTrainerCase):
class TestTrainerSFTWarmupGSM8K(BaseTrainerCase):
def test_trainer(self):
"""Test GSM8K With SFT."""
# test both mode
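The timing table in the new TestStepAheadAsyncRL comment can be read as a simple rule: the trainer syncs every sync_interval steps, while the explorer's sync points are shifted by sync_offset. The following is a hedged, self-contained sketch of that rule (illustrative only, not PR code; it assumes no sync is needed after the final step):

def explorer_sync_steps(total_steps: int, sync_interval: int, sync_offset: int) -> list:
    """Steps after which the explorer syncs, per the assumed offset rule."""
    return [
        s
        for s in range(1, total_steps)  # assume no sync after the final step
        if s > sync_offset and (s - sync_offset) % sync_interval == 0
    ]


def trainer_sync_steps(total_steps: int, sync_interval: int) -> list:
    """Steps after which the trainer syncs."""
    return [s for s in range(1, total_steps) if s % sync_interval == 0]


# With the values used in the test (4 steps, sync_interval=2, sync_offset=1):
assert explorer_sync_steps(4, 2, 1) == [3]  # explorer syncs after step 3
assert trainer_sync_steps(4, 2) == [2]      # trainer syncs after step 2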
36 changes: 28 additions & 8 deletions trinity/cli/launcher.py
@@ -9,6 +9,7 @@
import ray

from trinity.common.config import Config, load_config
from trinity.common.constants import EXPLORER_NAME, TRAINER_NAME
from trinity.explorer.explorer import Explorer
from trinity.trainer.trainer import Trainer
from trinity.utils.log import get_logger
@@ -19,7 +20,7 @@

def bench(config: Config) -> None:
"""Evaluate model."""
explorer = ray.remote(Explorer).options(name="explorer").remote(config)
explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
try:
ray.get(explorer.prepare.remote())
ray.get(explorer.benchmark.remote())
@@ -33,7 +34,7 @@ def bench(config: Config) -> None:
def explore(config: Config) -> None:
"""Run explorer."""
try:
explorer = ray.remote(Explorer).options(name="explorer").remote(config)
explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
ray.get(explorer.prepare.remote())
ray.get(explorer.sync_weight.remote())
ray.get(explorer.explore.remote())
@@ -46,7 +47,7 @@ def explore(config: Config) -> None:
def train(config: Config) -> None:
"""Run trainer."""
try:
trainer = ray.remote(Trainer).options(name="trainer").remote(config)
trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
ray.get(trainer.prepare.remote())
ray.get(trainer.sync_weight.remote())
ray.get(trainer.train.remote())
@@ -66,8 +67,8 @@ def train(config: Config) -> None:
the latest step. The specific number of experiences may vary for different
algorithms and tasks.
"""
explorer = ray.remote(Explorer).options(name="explorer").remote(config)
trainer = ray.remote(Trainer).options(name="trainer").remote(config)
explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()])
ray.get(
[
@@ -81,15 +82,34 @@ def both(config: Config) -> None:
trainer.sync_weight.remote(),
]
)
_, _ = ray.wait(
ready_ref, wait_ref = ray.wait(
[
explorer.explore.remote(),
trainer.train.remote(),
],
num_returns=1,
)
explorer.shutdown.remote(),
trainer.shutdown.remote(),

ready = ray.get(ready_ref[0])
if ready == TRAINER_NAME:
logger.info(
"===========================================================\n"
"> Launcher detected that the `Trainer` process has finished.\n"
"> Stopping the explorer process immediately.\n"
"==========================================================="
)
ray.wait(wait_ref, timeout=5)
elif ready == EXPLORER_NAME:
logger.info(
"============================================================\n"
"> Launcher detected that the `Explorer` process has finished.\n"
f"> Waiting {config.synchronizer.sync_timeout} s for the trainer process...\n"
"> You can force stop the Trainer process by pressing Ctrl+C.\n"
"============================================================"
)
ray.wait(wait_ref, timeout=config.synchronizer.sync_timeout)
explorer.shutdown.remote()
trainer.shutdown.remote()


def activate_data_module(data_workflow_url: str, config_path: str):
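The reworked both() above hinges on a single Ray idiom: wait for whichever of the two long-running actor calls returns first, identify it by the role name it returns, and give the other a bounded grace period. Below is a minimal standalone sketch of that idiom (the Role actor and its run method are placeholders, not Trinity APIs):

import ray


@ray.remote
class Role:
    """Toy stand-in for the Explorer/Trainer actors."""

    def __init__(self, name: str, iterations: int):
        self.name = name
        self.iterations = iterations

    def run(self) -> str:
        for _ in range(self.iterations):  # stand-in for real explore/train work
            pass
        return self.name


ray.init(ignore_reinit_error=True)
refs = [
    Role.remote("explorer", 10).run.remote(),
    Role.remote("trainer", 10_000_000).run.remote(),
]
# Block until the first actor call completes; keep a handle to the other one.
ready_ref, wait_ref = ray.wait(refs, num_returns=1)
finished_first = ray.get(ready_ref[0])
# Wait a bounded amount of time for the slower actor instead of blocking forever.
ray.wait(wait_ref, timeout=5)
print(f"{finished_first} finished first")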
11 changes: 11 additions & 0 deletions trinity/common/constants.py
@@ -8,6 +8,9 @@

# names

EXPLORER_NAME = "explorer"
TRAINER_NAME = "trainer"

ROLLOUT_WEIGHT_SYNC_GROUP_NAME = "rollout_weight_sync"


@@ -92,3 +95,11 @@ class SyncMethod(CaseInsensitiveEnum, metaclass=SyncMethodEnumMeta):

NCCL = "nccl"
CHECKPOINT = "checkpoint"


class RunningStatus(Enum):
"""Running status of explorer and trainer."""

RUNNING = "running"
WAITING_SYNC = "waiting_sync"
STOPPED = "stopped"
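The new EXPLORER_NAME and TRAINER_NAME constants replace string literals scattered across the launcher, the vLLM worker, and the trainer, all of which resolve the actors by name with ray.get_actor; sharing one constant means a typo surfaces at import time rather than as a runtime lookup error. A short illustrative sketch of the named-actor pattern (not PR code):

import ray

EXPLORER_NAME = "explorer"


@ray.remote
class Explorer:
    def ping(self) -> str:
        return "pong"


ray.init(ignore_reinit_error=True)
# The name used when creating the actor must match the name used for lookup,
# which is why both sides now import the same constant.
explorer = Explorer.options(name=EXPLORER_NAME).remote()
handle = ray.get_actor(EXPLORER_NAME)
print(ray.get(handle.ping.remote()))  # -> "pong"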
3 changes: 2 additions & 1 deletion trinity/common/models/vllm_worker.py
@@ -4,6 +4,7 @@
import torch
import torch.distributed

from trinity.common.constants import EXPLORER_NAME
from trinity.utils.distributed import init_process_group, is_ipv6_address
from trinity.utils.log import get_logger

@@ -60,7 +61,7 @@ def update_weight(self):
"""Broadcast weight to all vllm workers from source rank 0 (actor model)"""
assert self._state_dict_meta is not None
if self._explorer_actor is None:
self._explorer_actor = ray.get_actor(name="explorer")
self._explorer_actor = ray.get_actor(name=EXPLORER_NAME)
for name, dtype_str, shape in self._state_dict_meta:
if self._weight_update_rank == 0:
weight = ray.get(self._explorer_actor.get_weight.remote(name))
48 changes: 36 additions & 12 deletions trinity/explorer/explorer.py
@@ -14,7 +14,12 @@
from trinity.buffer import get_buffer_writer
from trinity.buffer.buffer import get_buffer_reader
from trinity.common.config import Config
from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
from trinity.common.constants import (
EXPLORER_NAME,
ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
RunningStatus,
SyncMethod,
)
from trinity.common.models import create_inference_models
from trinity.common.models.utils import (
get_checkpoint_dir_with_step_num,
@@ -50,7 +55,7 @@ def __init__(self, config: Config):
self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
project=self.config.project,
name=self.config.name,
role="explorer",
role=EXPLORER_NAME,
config=config,
)
self.batch_size = config.buffer.batch_size
@@ -69,6 +74,7 @@ def __init__(self, config: Config):
self.state_dict = {}
else: # nccl mode
self.state_dict_meta = []
self.status = RunningStatus.RUNNING
self.logger.info("Finished initializing Explorer.")

async def setup_weight_sync_group(
@@ -162,35 +168,44 @@ async def get_weight(self, name: str) -> torch.Tensor:
"""Get the weight of the loaded model (For checkpoint weights update)."""
return self.state_dict[name]

async def explore(self) -> None:
async def explore(self) -> str:
while True:
try:
explore_contionue = self.explore_step()
if not explore_contionue:
break
if self.need_sync():
self.wait_for_workflow_done()
await self.sync_weight()
if self.explore_step_num % self.config.explorer.eval_interval == 0:
self.wait_for_workflow_done()
self.eval()
if not explore_contionue:
break
except Exception as e:
self.logger.error(f"Error in Explorer: {e}")
break
self.logger.info("--------------------\n> Explorer finished.\n--------------------\n")
self.logger.info("--------------------\n> Explorer finished.\n--------------------")
return EXPLORER_NAME

def explore_step(self) -> bool:
self.explore_step_num += 1
algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num)
algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num + 1)
# skip warmup
if algo_config.algorithm_type == "sft":
self.explore_step_num += 1
return True
try:
tasks = self.taskset.read()
except StopIteration:
self.logger.warning("No more tasks to explore. Stop exploring.")
self.cache.save_explorer(
current_step=self.explore_step_num,
current_task_index=self.explore_step_num * self.config.buffer.batch_size,
)
self.status = RunningStatus.STOPPED
self.wait_for_workflow_done()
self.experience_buffer.finish()
return False
self.runner_pool.run_tasks(tasks)
self.explore_step_num += 1
return True

def need_sync(self) -> bool:
@@ -278,20 +293,25 @@ def wait_for_workflow_done(self) -> None:
if not status.ok:
self.logger.error(f"Error when running task: {status.message}")
# submit another task to replace the failed task
self.runner_pool.run_tasks(self.taskset.read(batch_size=1))
try:
tasks = self.taskset.read(batch_size=1)
except StopIteration:
self.logger.warning("No more tasks in taskset. Stop retrying.")
return
self.runner_pool.run_tasks(tasks)
else:
for metric_name, metric_value in status.metric.items():
all_metrics[metric_name].append(metric_value)
# calculate metrics
log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore
self.monitor.log(log_metrics, step=self.explore_step_num)

self.logger.info(f"Explore step {self.explore_step_num} finished.")

async def sync_weight(self) -> None:
"""Synchronize model weights."""
# call this method before training start to load the latest model weights
self.logger.info(f"Explorer synchronizing weights at step {self.explore_step_num}.")
self.logger.info(f"Explorer sync weights at step {self.explore_step_num}.")
self.status = RunningStatus.WAITING_SYNC
if self.use_checkpoint_weights_update:
await self._checkpoint_weights_update()
else: # nccl weights update
Expand All @@ -301,7 +321,11 @@ async def sync_weight(self) -> None:
current_step=self.explore_step_num,
current_task_index=self.explore_step_num * self.config.buffer.batch_size,
)
self.logger.info(f"Explorer synchronizing at step {self.explore_step_num} finished")
self.status = RunningStatus.RUNNING
self.logger.info(f"Explorer sync at step {self.explore_step_num} finished")

async def running_status(self) -> RunningStatus:
return self.status

def flush_log(self, step: int) -> None:
"""Flush the log of the current step."""
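Taken together, the explorer changes above implement a graceful-stop handshake: when the taskset raises StopIteration, the explorer saves its progress, flips its status to STOPPED, drains in-flight workflows, and finishes the experience buffer so the trainer is not left blocking. A condensed sketch of that flow under assumed helper objects (reader, pool, buffer, and cache are placeholders, not Trinity APIs):

from enum import Enum


class RunningStatus(Enum):
    RUNNING = "running"
    WAITING_SYNC = "waiting_sync"
    STOPPED = "stopped"


def explore_step(reader, pool, buffer, cache, state, batch_size) -> bool:
    """Return False once the taskset is exhausted, after winding down cleanly."""
    try:
        tasks = reader.read()
    except StopIteration:
        # Persist progress first, then flag STOPPED so the trainer can skip
        # its next weight sync instead of waiting on a departed peer.
        cache.save_explorer(
            current_step=state["step"],
            current_task_index=state["step"] * batch_size,
        )
        state["status"] = RunningStatus.STOPPED
        pool.wait_all()   # drain in-flight workflows
        buffer.finish()   # unblock the trainer's buffer reads
        return False
    pool.run_tasks(tasks)
    state["step"] += 1
    return True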
25 changes: 20 additions & 5 deletions trinity/trainer/trainer.py
@@ -7,8 +7,15 @@
import os
from abc import ABC, abstractmethod

import ray

from trinity.common.config import Config
from trinity.common.constants import SyncMethod
from trinity.common.constants import (
EXPLORER_NAME,
TRAINER_NAME,
RunningStatus,
SyncMethod,
)
from trinity.utils.log import get_logger


@@ -19,24 +26,26 @@ def __init__(self, config: Config) -> None:
self.config = config
self.logger = get_logger(__name__)
self.engine = get_trainer_wrapper(config)
self.explorer_ref = None

def prepare(self) -> None:
"""Prepare the trainer."""
self.engine.prepare()

def train(self):
def train(self) -> str:
"""Train the model."""
while True:
try:
train_continue = self.train_step()
if self.need_sync():
self.sync_weight()
if not train_continue:
break
if self.need_sync():
self.sync_weight()
except Exception as e:
self.logger.error(f"Error in Trainer: {e}")
break
self.logger.info("--------------------\n> Trainer finished.\n--------------------\n")
self.logger.info("--------------------\n> Trainer finished.\n--------------------")
return TRAINER_NAME

def train_step(self) -> bool:
"""Train one step.
Expand All @@ -53,6 +62,12 @@ def need_sync(self) -> bool:
def sync_weight(self) -> None:
"""Sync the model weight."""
if self.config.synchronizer.sync_method == SyncMethod.NCCL:
if self.explorer_ref is None:
self.explorer_ref = ray.get_actor(EXPLORER_NAME)
explorer_status = ray.get(self.explorer_ref.running_status.remote())
if explorer_status == RunningStatus.STOPPED:
self.logger.warning("Explorer has already stopped. Skipping sync weight.")
return
self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num}.")
self.engine.sync_weight()

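On the trainer side, the loop now checks the exit condition before attempting a sync, and an NCCL sync is skipped once the explorer reports STOPPED, since a collective against a departed peer would hang. A compressed sketch of that control flow (helper names are placeholders; in the PR the stopped-explorer check lives inside sync_weight and only applies to the NCCL sync method):

def train(engine, need_sync, explorer_stopped) -> str:
    """Run training steps until the engine reports completion."""
    while True:
        keep_going = engine.train_step()
        if not keep_going:
            break  # exit check comes before any sync attempt
        if need_sync(engine.train_step_num):
            if explorer_stopped():
                # Skip the sync: the explorer is gone, so a collective would hang.
                continue
            engine.sync_weight()
    return "trainer"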
8 changes: 6 additions & 2 deletions trinity/trainer/verl/fsdp_workers.py
@@ -71,7 +71,11 @@
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

from trinity.common.config import AlgorithmConfig
from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
from trinity.common.constants import (
EXPLORER_NAME,
ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
SyncMethod,
)
from trinity.utils.distributed import init_process_group, is_ipv6_address

logger = logging.getLogger(__file__)
@@ -573,7 +577,7 @@ def setup_weight_sync_group(self):
master_address, master_port = self.get_availale_master_addr_port()
world_size = self.config.synchronizer.explorer_world_size + 1
print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).")
explorer = ray.get_actor("explorer")
explorer = ray.get_actor(EXPLORER_NAME)
setup_ref = explorer.setup_weight_sync_group.remote(
master_address, master_port, self.state_dict_meta
)