22 changes: 16 additions & 6 deletions examples/bots/workflow/bots_reward.py
@@ -1,8 +1,10 @@
# Adapted from Reasoning360: https://github.com/LLM360/Reasoning360/blob/main/verl/utils/reward_score/naive_dapo.py

import concurrent
import contextlib
import math
import re
import resource
from math import isclose
from typing import Optional, Union

@@ -585,17 +587,25 @@ def should_allow_eval(expr: str):

# @timeout(timeout_seconds=10)
def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
are_equal = False
try:
def check_equal():
memory_size = 1024**3
resource.setrlimit(resource.RLIMIT_AS, (memory_size, memory_size))

expr = f"({ground_truth_normalized})-({given_normalized})"
if should_allow_eval(expr):
sympy_diff = _sympy_parse(expr)
simplified = sympy.simplify(sympy_diff)
if simplified == 0:
are_equal = True
except Exception:
pass
return are_equal
return True
return False

with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
future = executor.submit(check_equal)
try:
return future.result(timeout=10)
except (concurrent.futures.TimeoutError, Exception):
future.cancel()
return False


def split_tuple(expr: str):
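
Reviewer note on the hunk above: the sympy comparison now runs inside a single-worker process pool with a 1 GiB address-space cap, so a hung or memory-hungry simplify() call can no longer stall reward scoring. Below is a self-contained sketch of the same pattern; the helper names are illustrative, and parse_expr stands in for the module's _sympy_parse/should_allow_eval.

# Minimal sketch (assumed names, not the project's exact code): run the sympy
# equality check in a separate process so pathological expressions cannot hang
# or exhaust the scoring process.
import concurrent.futures
import resource

import sympy
from sympy.parsing.sympy_parser import parse_expr


def _check_equal(ground_truth: str, given: str) -> bool:
    # Runs inside the worker process: cap its address space at 1 GiB so a
    # runaway simplification cannot exhaust host memory (RLIMIT_AS is Unix-only).
    limit = 1024**3
    resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
    diff = parse_expr(f"({ground_truth})-({given})")
    return sympy.simplify(diff) == 0


def are_equal_with_timeout(ground_truth: str, given: str, timeout: float = 10.0) -> bool:
    # A dedicated single-worker pool keeps the sympy work out of the caller's process.
    # The worker function must be defined at module level so it can be pickled.
    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
        future = executor.submit(_check_equal, ground_truth, given)
        try:
            return future.result(timeout=timeout)
        except Exception:  # timeout, MemoryError, or parse failure all count as "not equal"
            return False


if __name__ == "__main__":
    print(are_equal_with_timeout("2*(x + 1)", "2*x + 2"))  # True
    print(are_equal_with_timeout("x + 1", "x"))            # False

One caveat worth noting: ProcessPoolExecutor.shutdown() still waits for a running worker on exit, so the timeout bounds how long the caller waits for a result, not necessarily how quickly a stuck worker process goes away.
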
5 changes: 5 additions & 0 deletions tests/trainer/trainer_test.py
@@ -113,6 +113,8 @@ def test_trainer(self):
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8)
actor_kl_metrics = parser.metric_list("actor/kl")
self.assertTrue(len(actor_kl_metrics) > 0)
actor_kl_loss = parser.metric_values("actor/kl_loss")
self.assertEqual(actor_kl_loss[0], 0.0)
critic_kl_metrics = parser.metric_list("critic/kl")
self.assertTrue(len(critic_kl_metrics) > 0)
response_metrics = parser.metric_list("response_length")
@@ -138,6 +140,9 @@ def test_trainer(self):
self.config.mode = "bench"
self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
self.config.explorer.bench_on_latest_checkpoint = False
self.config.buffer.explorer_input.taskset = None
self.config.buffer.explorer_input.tasksets = []
self.config.buffer.trainer_input.experience_buffer = None
self.config.check_and_update()
bench(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
2 changes: 2 additions & 0 deletions trinity/buffer/pipelines/task_pipeline.py
@@ -5,6 +5,8 @@


def check_and_run_task_pipeline(config: Config) -> Dict:
if config.mode not in {"explore", "train", "both"}:
return {}
if config.data_processor.task_pipeline is None:
return {}

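The added guard means task-pipeline setup only runs in modes that actually consume tasksets. A tiny, self-contained illustration of the same early-return shape follows; the config class is a stand-in, not the project's Config.

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class FakeConfig:  # stand-in for trinity.common.config.Config
    mode: str = "both"
    task_pipeline: Optional[object] = None


def check_and_run_task_pipeline(config: FakeConfig) -> Dict:
    if config.mode not in {"explore", "train", "both"}:
        return {}  # e.g. "bench" / "serve" never consume the pipeline output
    if config.task_pipeline is None:
        return {}
    return {"task_pipeline": "ran"}


print(check_and_run_task_pipeline(FakeConfig(mode="bench")))  # {} -- skipped entirely
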
26 changes: 17 additions & 9 deletions trinity/common/config.py
@@ -853,8 +853,8 @@ def _check_interval(self) -> None:
)

def _check_explorer_input(self) -> None:
if self.mode == "train":
# no need to check explorer_input in train mode
if self.mode in {"train", "serve"}:
# no need to check explorer_input in serve mode
return

explorer_input = self.buffer.explorer_input
@@ -864,12 +864,11 @@ def _check_explorer_input(self) -> None:
raise ValueError("Do not support setting `taskset` and `tasksets` simultaneously!")
explorer_input.tasksets = [explorer_input.taskset]
explorer_input.taskset = None
elif len(explorer_input.tasksets) == 0:
elif self.mode != "bench" and len(explorer_input.tasksets) == 0:
raise ValueError("At least one taskset should be provided in explorer_input!")
tasksets = explorer_input.tasksets

for i, taskset in enumerate(tasksets):
if self.mode != "train" and not taskset.path:
for i, taskset in enumerate(explorer_input.tasksets):
if not taskset.path:
raise ValueError(
"`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
)
@@ -914,6 +913,10 @@ def _check_explorer_input(self) -> None:
set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens)

def _check_trainer_input(self) -> None:
if self.mode == "bench":
# no need to check trainer_input in bench mode
return

trainer_input = self.buffer.trainer_input
experience_buffer = trainer_input.experience_buffer

@@ -973,7 +976,7 @@ def _default_storage_path(self, storage_type: StorageType, name: str) -> str:
def _check_data_processor(self) -> None:
# check input/output buffers in pipelines
experience_pipeline = self.data_processor.experience_pipeline
if experience_pipeline is not None:
if experience_pipeline is not None and self.mode in {"explore", "both", "serve"}:
if experience_pipeline.save_input and experience_pipeline.input_save_path is None:
experience_pipeline.input_save_path = os.path.join(
self.buffer.cache_dir, "explorer_output.jsonl" # type: ignore[arg-type]
@@ -983,10 +986,15 @@ def _check_data_processor(self) -> None:
)

task_pipeline = self.data_processor.task_pipeline
if task_pipeline is not None:
if task_pipeline is not None and self.mode in {"explore", "train", "both"}:
if task_pipeline.output is None:
if self.mode != "train":
task_pipeline.output = self.buffer.explorer_input.tasksets[0]
if len(self.buffer.explorer_input.tasksets) > 0:
task_pipeline.output = self.buffer.explorer_input.tasksets[0]
else:
raise ValueError(
"At least one taskset should be provided in explorer_input!"
)
elif self.mode == "train" and self.algorithm.algorithm_type in {"dpo", "sft"}:
task_pipeline.output = self.buffer.trainer_input.experience_buffer
else:
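
These hunks make the config checks mode-aware: explorer input is only validated when the run actually explores, trainer input is skipped in bench mode, and the task pipeline only falls back to the first taskset when one exists. A condensed sketch of the explorer-input check is below; the field names mirror the diff, but the dataclasses are simplified stand-ins.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Taskset:
    path: Optional[str] = None


@dataclass
class ExplorerInput:
    taskset: Optional[Taskset] = None            # legacy single-taskset field
    tasksets: List[Taskset] = field(default_factory=list)


def check_explorer_input(mode: str, explorer_input: ExplorerInput) -> None:
    if mode in {"train", "serve"}:
        return  # these modes never read tasksets
    if explorer_input.taskset is not None:
        if explorer_input.tasksets:
            raise ValueError("Do not support setting `taskset` and `tasksets` simultaneously!")
        explorer_input.tasksets = [explorer_input.taskset]
        explorer_input.taskset = None
    elif mode != "bench" and not explorer_input.tasksets:
        raise ValueError("At least one taskset should be provided in explorer_input!")
    for taskset in explorer_input.tasksets:
        if not taskset.path:
            raise ValueError("`buffer.explorer_input.taskset.path` is required.")

Note the asymmetry: "bench" tolerates an empty tasksets list, while "explore"/"both" require at least one entry, and any taskset that is present must still carry a path.
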
3 changes: 1 addition & 2 deletions trinity/common/verl_config.py
@@ -85,6 +85,7 @@ class FSDPConfig:
wrap_policy: WrapPolicy = field(default_factory=WrapPolicy)
fsdp_size: int = -1
forward_prefetch: bool = False
model_dtype: Optional[str] = None


@dataclass
@@ -163,8 +164,6 @@ class Actor:
clip_ratio_high: Optional[float] = None
entropy_coeff: float = 0.001
use_kl_loss: bool = False
kl_loss_coef: float = 0.0
kl_loss_type: str = "low_var_kl"


@dataclass
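
With model_dtype now a typed optional field instead of a value fetched via .get(...), the FSDP workers can resolve the dtype with an explicit fallback. A small sketch of that resolution follows; the dataclass is reduced to the one field, and a plain dict stands in for verl's PrecisionType.to_dtype mapping.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FSDPConfig:
    model_dtype: Optional[str] = None  # e.g. "fp32", "bf16", "fp16"


_DTYPES = {"fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}


def resolve_model_dtype(cfg: FSDPConfig, is_actor: bool) -> torch.dtype:
    if cfg.model_dtype is None:
        # Mirror the worker's default: actor modules build in fp32,
        # reference-style modules fall back to bf16.
        return torch.float32 if is_actor else torch.bfloat16
    return _DTYPES[cfg.model_dtype]


print(resolve_model_dtype(FSDPConfig(), is_actor=True))         # torch.float32
print(resolve_model_dtype(FSDPConfig("bf16"), is_actor=False))  # torch.bfloat16
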
9 changes: 7 additions & 2 deletions trinity/explorer/explorer.py
@@ -52,7 +52,9 @@ def __init__(self, config: Config):
self.models, self.auxiliary_models = create_inference_models(config)
self.experience_pipeline = self._init_experience_pipeline()
self.taskset = (
TasksetScheduler(explorer_state, config) if self.config.mode != "serve" else None
TasksetScheduler(explorer_state, config)
if self.config.mode not in {"bench", "serve"}
else None
)
self.scheduler = None
self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
@@ -151,7 +153,8 @@ async def prepare(self) -> None:
"""Preparation before running."""
try:
# prepare experience pipeline
await self.experience_pipeline.prepare.remote()
if self.experience_pipeline:
await self.experience_pipeline.prepare.remote()
self.logger.info("Experience pipeline is ready.")
# make sure all rollout models are ready
run_api_ref = [model.run_api_server.remote() for model in self.models]
@@ -406,6 +409,8 @@ async def is_alive(self) -> bool:

def _init_experience_pipeline(self) -> ray.actor.ActorHandle:
"""Init experience pipeline for the explorer."""
if self.config.mode == "bench":
return None
node_id = ray.get_runtime_context().get_node_id()
return (
ray.remote(ExperiencePipeline)
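
Because the explorer no longer creates an experience pipeline in bench mode, every later call has to tolerate a missing handle, which is what the new guard in prepare() does. A stripped-down sketch of the pattern with Ray; the class and method bodies here are illustrative, not the project's pipeline.

import ray


@ray.remote
class ExperiencePipeline:  # toy stand-in for the project's pipeline actor
    def prepare(self) -> str:
        return "ready"


def init_experience_pipeline(mode: str):
    if mode == "bench":
        return None  # bench runs have nothing to prepare
    return ExperiencePipeline.remote()


ray.init(ignore_reinit_error=True)
pipeline = init_experience_pipeline("both")
if pipeline is not None:  # the same guard the prepare() hunk adds
    print(ray.get(pipeline.prepare.remote()))  # "ready"
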
5 changes: 4 additions & 1 deletion trinity/manager/synchronizer.py
@@ -44,7 +44,10 @@ def __init__(self, config: Config, module_ref: ray.actor.ActorHandle):
self._modules = {module_ref}
self._modules_lock = asyncio.Lock()
asyncio.create_task(self._check_modules())
if self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT:
if (
self.config.mode != "bench"
and self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT
):
asyncio.create_task(self._find_latest_state_dict())

async def add_module(self, module_ref: ray.actor.ActorHandle) -> None:
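
The synchronizer now gates its checkpoint-polling task on the run mode as well as the sync method. A minimal asyncio sketch of conditionally starting such a background task; names and intervals are illustrative.

import asyncio


async def find_latest_state_dict() -> None:
    # Stand-in for the synchronizer's checkpoint poller.
    while True:
        await asyncio.sleep(1)
        print("checked for a newer checkpoint")


async def start_synchronizer(mode: str, sync_method: str) -> None:
    poller = None
    if mode != "bench" and sync_method == "checkpoint":
        # Only non-bench checkpoint-sync runs need to watch for new state dicts.
        poller = asyncio.create_task(find_latest_state_dict())
    await asyncio.sleep(3)  # stand-in for the synchronizer's real work
    if poller is not None:
        poller.cancel()


asyncio.run(start_synchronizer(mode="bench", sync_method="checkpoint"))  # poller never starts
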
23 changes: 12 additions & 11 deletions trinity/trainer/verl/fsdp_workers.py
@@ -168,7 +168,10 @@ def __init__(self, config: DictConfig, role: str):
self.config.actor.ppo_micro_batch_size
)

if self.config.actor.ppo_micro_batch_size_per_gpu is not None:
if (
not self.config.actor.use_dynamic_bsz
and self.config.actor.ppo_micro_batch_size_per_gpu is not None
):
assert (
self.config.actor.ppo_mini_batch_size
% self.config.actor.ppo_micro_batch_size_per_gpu
@@ -181,7 +184,11 @@ def __init__(self, config: DictConfig, role: str):
), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"

# normalize ref config
if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
if (
self._is_ref
and not self.config.ref.log_prob_use_dynamic_bsz
and self.config.ref.log_prob_micro_batch_size is not None
):
self.config.ref.log_prob_micro_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
@@ -246,7 +253,7 @@ def _build_model_optimizer( # noqa: C901
else:
self.tokenizer.chat_template = self.config.model.custom_chat_template

torch_dtype = fsdp_config.get("model_dtype", None)
torch_dtype = fsdp_config.model_dtype
if torch_dtype is None:
torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
else:
@@ -326,9 +333,6 @@ def _build_model_optimizer( # noqa: C901
fused_kernels_backend=fused_kernels_backend,
)

# some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
actor_module.to(torch_dtype)

if enable_gradient_checkpointing:
actor_module.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
@@ -971,7 +975,7 @@ def __init__(self, config):
self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size

if self.config.ppo_micro_batch_size_per_gpu is not None:
if not self.config.use_dynamic_bsz and self.config.ppo_micro_batch_size_per_gpu is not None:
assert (
self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
@@ -1020,7 +1024,7 @@ def _build_critic_model_optimizer(self, config): # noqa: C901
if self.rank == 0:
print(f"Critic overriding config {override_config_kwargs}")

torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
torch_dtype = self.config.model.fsdp_config.model_dtype or "fp32"
torch_dtype = PrecisionType.to_dtype(torch_dtype)

from transformers import AutoConfig
@@ -1060,9 +1064,6 @@ def _build_critic_model_optimizer(self, config): # noqa: C901
ulysses_sp_size=self.ulysses_sequence_parallel_size,
)

# some parameters may not in torch_dtype
critic_module.to(torch_dtype)

if config.model.get("enable_gradient_checkpointing", False):
critic_module.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
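
The batch-size hunks above all apply the same rule: the static divisibility assertions are only meaningful when dynamic (token-budget) batching is off, since use_dynamic_bsz sizes micro-batches at runtime. A compact sketch of that check with simplified, assumed argument names:

from typing import Optional


def check_micro_batch_settings(
    ppo_mini_batch_size: int,
    ppo_micro_batch_size_per_gpu: Optional[int],
    world_size: int,
    ulysses_sp_size: int = 1,
    use_dynamic_bsz: bool = False,
) -> None:
    if use_dynamic_bsz or ppo_micro_batch_size_per_gpu is None:
        return  # token-budget batching decides micro-batch shapes at runtime
    # The worker first normalizes the mini batch across data-parallel ranks.
    normalized = ppo_mini_batch_size // (world_size // ulysses_sp_size)
    assert normalized % ppo_micro_batch_size_per_gpu == 0, (
        f"normalized ppo_mini_batch_size {normalized} should be divisible by "
        f"ppo_micro_batch_size_per_gpu {ppo_micro_batch_size_per_gpu}"
    )
    assert normalized // ppo_micro_batch_size_per_gpu > 0


check_micro_batch_settings(256, 4, world_size=8)                        # ok: 32 per rank
check_micro_batch_settings(256, 4, world_size=8, use_dynamic_bsz=True)  # check skipped
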
72 changes: 39 additions & 33 deletions trinity/trainer/verl/megatron_checkpoint_manager.py
@@ -233,44 +233,50 @@ def save_checkpoint( # noqa: C901
json.dump(transformer_config_dict, f, indent=2)

if self.should_save_hf_model or save_as_hf:
# wait for everyone to dump to local
state_dict = self.weight_saver(
self.model,
self.hf_config,
dtype=self.param_dtype,
is_value_model=self.is_value_model,
tie_word_embeddings=self.share_embeddings_and_output_weights,
)
try:
# wait for everyone to dump to local
state_dict = self.weight_saver(
self.model,
self.hf_config,
dtype=self.param_dtype,
is_value_model=self.is_value_model,
tie_word_embeddings=self.share_embeddings_and_output_weights,
)

torch.distributed.barrier()
if self.rank == 0:
# TODO: async save or use mbridge to save hf model
hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
import warnings
torch.distributed.barrier()
if self.rank == 0:
# TODO: async save or use mbridge to save hf model
hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
import warnings

from accelerate import init_empty_weights
from accelerate import init_empty_weights

with init_empty_weights(), warnings.catch_warnings():
warnings.simplefilter("ignore")
if "mistral7b-rm" in self.config.model.path:
from transformers import MistralForSequenceClassification
with init_empty_weights(), warnings.catch_warnings():
warnings.simplefilter("ignore")
if "mistral7b-rm" in self.config.model.path:
from transformers import MistralForSequenceClassification

model = MistralForSequenceClassification.from_pretrained(
self.config.model.path
) # use score head instead of lm_head
state_dict["score.weight"] = state_dict["score.weight"]
else:
from transformers import AutoModelForCausalLM
model = MistralForSequenceClassification.from_pretrained(
self.config.model.path
) # use score head instead of lm_head
state_dict["score.weight"] = state_dict["score.weight"]
else:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
self.config.model.path, torch_dtype="auto"
)
model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
log_with_rank(
f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
rank=self.rank,
logger=logger,
log_only_rank_0=True,
model = AutoModelForCausalLM.from_pretrained(
self.config.model.path, torch_dtype="auto"
)
model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
log_with_rank(
f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
rank=self.rank,
logger=logger,
log_only_rank_0=True,
)
except Exception:
logger.error(
f"Failed to save Huggingface model to {local_path}, you can try to set `use_mbridge=true` to save it.",
exc_info=True,
)

ray.get(
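
The checkpoint hunk wraps the whole HuggingFace export in a try/except so a failed conversion degrades to an error log (with a hint to set use_mbridge=true) instead of killing the run. A minimal sketch of that best-effort guard; the export body and function names here are placeholders.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def save_hf_model_safely(export_fn, local_path: str) -> None:
    # Best-effort export: the Megatron checkpoint is already on disk, so a failure
    # here should be logged loudly but must not abort training.
    try:
        export_fn(local_path)
    except Exception:
        logger.error(
            f"Failed to save Huggingface model to {local_path}, "
            "you can try to set `use_mbridge=true` to save it.",
            exc_info=True,  # keep the traceback in the log for debugging
        )


def broken_export(path: str) -> None:
    raise RuntimeError("weight conversion failed")


save_hf_model_safely(broken_export, "/tmp/ckpt")  # logs the error; execution continues
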