tests/template/verl_config.yaml (4 changes: 3 additions & 1 deletion)

@@ -16,7 +16,7 @@ actor_rollout_ref:
     shuffle: False
     ulysses_sequence_parallel_size: 1 # sp size
     checkpoint:
-      contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
+      contents: ["model", "optimizer", "extra"] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
     optim:
       lr: 1e-6
       lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
@@ -72,6 +72,8 @@ critic:
   shuffle: ${actor_rollout_ref.actor.shuffle}
   grad_clip: 1.0
   cliprange_value: 0.5
+  checkpoint:
+    contents: ["model", "optimizer", "extra"] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space

 trainer:
   balance_batch: True
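For context, a minimal sketch of how a checkpoint writer might interpret the `contents` list; the helper is hypothetical and not part of verl, only the `contents` values come from the config above.

```python
# Hypothetical helper illustrating the `contents` switch: including
# "hf_model" would additionally save the whole model in HF format, while
# the committed config keeps only the sharded checkpoint to save space.
contents = ["model", "optimizer", "extra"]  # value from verl_config.yaml

def should_save_hf_model(contents: list) -> bool:
    return "hf_model" in contents

assert not should_save_hf_model(contents)
```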
trinity/algorithm/sample_strategy/sample_strategy.py (38 changes: 33 additions & 5 deletions)

@@ -12,26 +12,40 @@


 class SampleStrategy(ABC):
-    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
+    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs) -> None:
         self.pad_token_id = buffer_config.pad_token_id
         self.trainer_type = trainer_type

     @abstractmethod
     def sample(self, step: int) -> Tuple[Any, Dict, List]:
-        """Sample experiences from buffer.
+        """Sample data from buffer.

         Args:
             step (`int`): The step number of current step.

         Returns:
-            `Any`: The sampled experiences.
+            `Any`: The sampled data.
             `Dict`: Metrics for logging.
-            `List`: Representative experiences for logging.
+            `List`: Representative data for logging.
         """

+    # Experimental API
+    @abstractmethod
+    def warmup_state(self, step: int) -> Tuple[bool, bool]:
+        """Check the warmup state of the current step.
+
+        Args:
+            step (`int`): The step number of current step.
+
+        Returns:
+            `bool`: Current step is in warmup or not.
+            `bool`: Warmup is finished on this step or not.
+        """
+
     @classmethod
     @abstractmethod
     def default_args(cls) -> dict:
-        return {}
+        """Get the default arguments of the sample strategy."""


 @SAMPLE_STRATEGY.register_module("warmup")
@@ -70,6 +84,13 @@ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
         else:
             raise NotImplementedError(f"backend {self.trainer_type} is not supported")

+    def warmup_state(self, step: int) -> Tuple[bool, bool]:
+        return step <= self.sft_warmup_steps, step == self.sft_warmup_steps
+
+    @classmethod
+    def default_args(cls) -> dict:
+        return {}
+

 @SAMPLE_STRATEGY.register_module("default")
 class DefaultSampleStrategy(SampleStrategy):
@@ -93,6 +114,13 @@ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
         else:
             raise NotImplementedError(f"backend {self.trainer_type} is not supported")

+    def warmup_state(self, step: int) -> Tuple[bool, bool]:
+        return False, False
+
+    @classmethod
+    def default_args(cls) -> dict:
+        return {}
+

 @SAMPLE_STRATEGY.register_module("dpo")
 class DPOSampleStrategy(WarmupSampleStrategy):
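A self-contained sketch (not the project's actual trainer loop) of how a caller might consume the new `warmup_state` API; `sft_warmup_steps = 2` is an illustrative value.

```python
from typing import Tuple

class DemoWarmupStrategy:
    """Mirrors WarmupSampleStrategy.warmup_state from the diff above."""

    sft_warmup_steps = 2  # illustrative value

    def warmup_state(self, step: int) -> Tuple[bool, bool]:
        return step <= self.sft_warmup_steps, step == self.sft_warmup_steps

strategy = DemoWarmupStrategy()
for step in range(1, 5):
    in_warmup, warmup_finished = strategy.warmup_state(step)
    print(f"step {step}: in_warmup={in_warmup}, warmup_finished={warmup_finished}")
# step 1: warmup; step 2: warmup and also the step on which warmup finishes;
# steps 3 and 4: regular sampling.
```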
trinity/cli/launcher.py (90 changes: 32 additions & 58 deletions)

@@ -2,6 +2,7 @@
 import argparse
 import os
 import sys
+import traceback
 from pathlib import Path
 from pprint import pprint

@@ -18,44 +19,41 @@

 def bench(config: Config) -> None:
     """Evaluate model."""
-    explorer = Explorer.remote(config)
+    explorer = ray.remote(Explorer).options(name="explorer").remote(config)
     try:
         ray.get(explorer.prepare.remote())
         ray.get(explorer.benchmark.remote())
         logger.info("Benchmark finished.")
         ray.get(explorer.shutdown.remote())
-    except Exception as e:
-        logger.error(f"Benchmark failed: {e}")
-        raise e
+    except Exception:
+        error_msg = traceback.format_exc()
+        logger.error(f"Benchmark failed:\n{error_msg}")


 def explore(config: Config) -> None:
     """Run explorer."""
-    explorer = Explorer.remote(config)
     try:
+        explorer = ray.remote(Explorer).options(name="explorer").remote(config)
         ray.get(explorer.prepare.remote())
         ray.get(explorer.sync_weight.remote())
         ray.get(explorer.explore.remote())
         logger.info("Explore finished.")
         ray.get(explorer.shutdown.remote())
-    except Exception as e:
-        logger.error(f"Explore failed: {e}")
-        raise e
+    except Exception:
+        error_msg = traceback.format_exc()
+        logger.error(f"Explorer failed:\n{error_msg}")


 def train(config: Config) -> None:
     """Run trainer."""
-
-    trainer = Trainer.remote(config)
-    ray.get(trainer.prepare.remote())
-
     try:
+        trainer = ray.remote(Trainer).options(name="trainer").remote(config)
+        ray.get(trainer.prepare.remote())
         ray.get(trainer.sync_weight.remote())
         ray.get(trainer.train.remote())
         logger.info("Train finished.")
         ray.get(trainer.shutdown.remote())
-    except Exception as e:
-        logger.error(f"Train failed {e}.")
-        raise e
+    except Exception:
+        error_msg = traceback.format_exc()
+        logger.error(f"Trainer failed:\n{error_msg}")


 def both(config: Config) -> None:
@@ -68,54 +66,30 @@ def both(config: Config) -> None:
     the latest step. The specific number of experiences may vary for different
     algorithms and tasks.
     """
-    explorer = Explorer.remote(config)
-    trainer = Trainer.remote(config)
+    explorer = ray.remote(Explorer).options(name="explorer").remote(config)
+    trainer = ray.remote(Trainer).options(name="trainer").remote(config)
+    ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()])
     logger.info("Setup explorer and trainer finished.")
     ray.get(
         [
             explorer.prepare.remote(),
             trainer.prepare.remote(),
         ]
     )
     # sync weight before training start
-    ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])
-
-    while True:
-        try:
-            ref_explore = explorer.explore_one_period.remote()
-            ref_train = trainer.train_one_period.remote()
-            explore_continue, explore_step_num = ray.get(ref_explore)
-            train_continue, train_step_num = ray.get(ref_train)
-            if not explore_continue:
-                # If explore finished, the trainer may not have enough experiences to continue,
-                # which will cause the trainer be blocked. So we stop the training process
-                # immediately.
-                # TODO: use a more elegant way to stop the training process.
-                logger.info("Explorer finished, stopping...")
-                break
-            if not train_continue:
-                logger.info("Trainer finished, stopping...")
-                break
-            ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])
-            logger.info("Model weight synchronized.")
-        except Exception as e:
-            logger.error(e)
-            logger.error("Training stopped due to exception.")
-            raise e
-        if explore_step_num % config.explorer.eval_interval == 0:
-            try:
-                ray.get(explorer.eval.remote())
-                logger.info("Evaluation finished.")
-            except Exception as e:
-                logger.error(e)
-                logger.error("Evaluation failed.")
-                raise e
-        ray.get(explorer.flush_log.remote(step=explore_step_num))
-        ray.get(trainer.flush_log.remote(step=train_step_num))
-
-    ray.get(explorer.shutdown.remote())
-    ray.get(trainer.shutdown.remote())
+    ray.get(
+        [
+            explorer.sync_weight.remote(),
+            trainer.sync_weight.remote(),
+        ]
+    )
+    _, _ = ray.wait(
+        [
+            explorer.explore.remote(),
+            trainer.train.remote(),
+        ],
+        num_returns=1,
+    )
+    explorer.shutdown.remote(),
+    trainer.shutdown.remote(),


 def activate_data_module(data_workflow_url: str, config_path: str):
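Two patterns in this rewrite are worth noting. First, `ray.remote(Explorer).options(name="explorer")` registers each actor under a fixed name, so other processes can later look it up with `ray.get_actor("explorer")`, and `traceback.format_exc()` logs the full stack trace rather than just the exception message. Second, `both` switches from a lock-step per-period loop to launching `explore` and `train` concurrently and unblocking as soon as either finishes. Below is a self-contained sketch of that `ray.wait` pattern, with illustrative task names and durations standing in for the real explorer and trainer loops.

```python
import time

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def run(role: str, seconds: float) -> str:
    time.sleep(seconds)  # stand-in for explore()/train()
    return role

refs = [run.remote("explorer", 0.1), run.remote("trainer", 10.0)]
# ray.wait returns once `num_returns` tasks have finished, so the driver
# unblocks as soon as the first of the two loops completes instead of
# waiting for both.
done, pending = ray.wait(refs, num_returns=1)
print(ray.get(done[0]))  # "explorer" finishes first; "trainer" keeps running
```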
trinity/common/config.py (4 changes: 3 additions & 1 deletion)

@@ -319,8 +319,10 @@ class SynchronizerConfig:
     sync_method: SyncMethod = SyncMethod.NCCL
     # sync weights every `sync_interval` steps
     sync_interval: int = 1
+    # allow explorer to run `sync_offset` steps before sync
+    sync_offset: int = 0
     # waiting for `sync_timeout` seconds before timeout in `nccl` method
-    sync_timeout: int = 1200
+    sync_timeout: int = 1800
     # wait for the lastest checkpoint to be ready # TODO: to be used
     wait_for_checkpoint: bool = False
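A sketch of how these knobs could interact, based only on the comments above; `should_sync` is a hypothetical helper, not the project's actual synchronizer logic.

```python
sync_interval, sync_offset = 2, 1  # illustrative values; the defaults are 1 and 0

def should_sync(explorer_step: int) -> bool:
    # Reading of the comments above: the explorer runs `sync_offset` steps
    # ahead before the first sync, then syncs every `sync_interval` steps.
    return explorer_step > sync_offset and (explorer_step - sync_offset) % sync_interval == 0

print([step for step in range(1, 10) if should_sync(step)])  # [3, 5, 7, 9]
```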
trinity/common/models/utils.py (1 change: 0 additions & 1 deletion)

@@ -156,7 +156,6 @@ def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None):
             iteration = f.read().strip()
         return os.path.join(checkpoint_path, f"global_step_{iteration}")
     else:
-        logger.error(f"No iteration file found in {checkpoint_path}")
         raise FileNotFoundError(f"No iteration file found in {checkpoint_path}")
 else:
     # load specific iteration checkpoint
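Dropping the `logger.error` removes a log-and-raise duplication: the exception already carries the message, and whichever caller catches it (as launcher.py now does with `traceback.format_exc()`) logs it again. A minimal illustration with generic Python logging, not Trinity code:

```python
import logging

logging.basicConfig()
logger = logging.getLogger("demo")

def find_checkpoint(path: str) -> str:
    # Old pattern: a logger.error(...) here, immediately before the raise,
    # makes the same message appear twice once the caller logs the exception.
    raise FileNotFoundError(f"No iteration file found in {path}")

try:
    find_checkpoint("/tmp/checkpoints")
except FileNotFoundError:
    logger.exception("lookup failed")  # single, complete report with traceback
```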
trinity/common/models/vllm_async_model.py (12 changes: 4 additions & 8 deletions)

@@ -267,10 +267,9 @@ async def _collective_rpc(

     async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
         """Sync model weights to vLLM."""
-        if self.state_dict_meta is None:
-            self.state_dict_meta = update_weight_args_list
-        for args in self.state_dict_meta:
-            await self._collective_rpc("update_weight", args=args)
+        if update_weight_args_list is not None:
+            await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
+        await self._collective_rpc("update_weight")
         self.logger.info("Sync model weights to vLLM successfully.")
         self.ckp_version += 1
         return True
@@ -287,7 +286,6 @@ async def init_process_group(
         update_with_checkpoint: bool = True,
         state_dict_meta: dict = None,
     ):
-        self.state_dict_meta = state_dict_meta
         return await self._collective_rpc(
             "init_process_group",
             args=(
@@ -299,12 +297,10 @@
                 backend,
                 timeout,
                 update_with_checkpoint,
-                state_dict_meta,
             ),
         )

-    async def update_weight(self, name, dtype, shape, empty_cache=False):
-        return await self._collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))
-
     async def run_api_server(self):
         """Run the OpenAI API server in a Ray actor.
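The refactor replaces one `update_weight` RPC per tensor with two calls: `set_state_dict_meta` ships the per-tensor metadata list once, and a single argument-less `update_weight` tells every worker to iterate that cached list itself. A hedged sketch of the worker-side counterpart these RPCs imply; the class and the broadcast step are illustrative, not the actual vLLM worker extension.

```python
from typing import List, Tuple

class WorkerExtensionSketch:
    """Illustrative receiver for the two RPCs used by sync_model above."""

    def set_state_dict_meta(self, update_weight_args_list: List[Tuple]) -> None:
        # Cache the (name, dtype, shape, ...) tuples once per sync round.
        self.state_dict_meta = update_weight_args_list

    def update_weight(self) -> None:
        # One RPC drives the whole sync: each worker walks the cached meta
        # and receives every tensor (e.g. via an NCCL broadcast) itself.
        for name, dtype, shape, *extra in self.state_dict_meta:
            ...  # receive/copy the tensor for `name` into the model
```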
trinity/common/models/vllm_model.py (13 changes: 4 additions & 9 deletions)

@@ -100,7 +100,6 @@ def init_process_group(
         update_with_checkpoint: bool = True,
         state_dict_meta: dict = None,
     ):
-        self.state_dict_meta = state_dict_meta
         return self.llm.collective_rpc(
             "init_process_group",
             args=(
@@ -112,12 +111,10 @@
                 backend,
                 timeout,
                 update_with_checkpoint,
-                state_dict_meta,
             ),
         )

-    def update_weight(self, name, dtype, shape, empty_cache=False):
-        return self.llm.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))
-
     def reset_prefix_cache(self):
         self.llm.llm_engine.reset_prefix_cache()

@@ -279,11 +276,9 @@ def has_api_server(self) -> bool:

     def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
         """Sync model weights to vLLM."""
-        if self.state_dict_meta is None:
-            self.state_dict_meta = update_weight_args_list
-        with self.lock:
-            for args in self.state_dict_meta:
-                self.llm.collective_rpc("update_weight", args=args)
+        if update_weight_args_list is not None:
+            self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
+        self._collective_rpc("update_weight")
         self.logger.info("Sync model weights to vLLM successfully.")
         self.ckp_version += 1
         return True