2 changes: 2 additions & 0 deletions tests/trainer/trainer_test.py
@@ -113,6 +113,8 @@ def test_trainer(self):
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8)
actor_kl_metrics = parser.metric_list("actor/kl")
self.assertTrue(len(actor_kl_metrics) > 0)
actor_kl_loss = parser.metric_values("actor/kl_loss")
self.assertEqual(actor_kl_loss[0], 0.0)
critic_kl_metrics = parser.metric_list("critic/kl")
self.assertTrue(len(critic_kl_metrics) > 0)
response_metrics = parser.metric_list("response_length")
3 changes: 2 additions & 1 deletion trinity/cli/launcher.py
@@ -37,6 +37,7 @@ def bench(config: Config) -> None:

def explore(config: Config) -> None:
"""Run explorer."""
check_and_run_task_pipeline(config)
try:
explorer = Explorer.get_actor(config)
ray.get(explorer.prepare.remote())
@@ -81,6 +82,7 @@ def both(config: Config) -> None:
the latest step. The specific number of experiences may vary for different
algorithms and tasks.
"""
check_and_run_task_pipeline(config)
try:
explorer = Explorer.get_actor(config)
trainer = Trainer.get_actor(config)
@@ -151,7 +153,6 @@ def run_stage(config: Config) -> None:
)
pprint(config)
try:
check_and_run_task_pipeline(config)
MODE_MAP[config.mode](config)
finally:
if config.monitor.enable_ray_timeline:
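
The relocation above means the task pipeline now runs only for modes that actually consume tasksets. A minimal, self-contained sketch of the resulting control flow, with hypothetical stand-ins for the real entry points (only the placement of check_and_run_task_pipeline mirrors the diff):

def check_and_run_task_pipeline(config: dict) -> None:
    print(f"running task pipeline for mode={config['mode']}")

def explore(config: dict) -> None:
    check_and_run_task_pipeline(config)  # moved here from run_stage
    print("starting explorer")

def train(config: dict) -> None:
    print("starting trainer")            # train mode never touches the task pipeline

MODE_MAP = {"explore": explore, "train": train}

def run_stage(config: dict) -> None:
    MODE_MAP[config["mode"]](config)     # run_stage no longer calls the pipeline itself

run_stage({"mode": "train"})             # -> only "starting trainer" is printed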
15 changes: 9 additions & 6 deletions trinity/common/config.py
@@ -853,8 +853,8 @@ def _check_interval(self) -> None:
)

def _check_explorer_input(self) -> None:
if self.mode == "train":
# no need to check explorer_input in train mode
if self.mode in {"train", "bench", "serve"}:
# no need to check explorer_input in train/bench/serve mode
return

explorer_input = self.buffer.explorer_input
@@ -866,9 +866,8 @@ def _check_explorer_input(self) -> None:
explorer_input.taskset = None
elif len(explorer_input.tasksets) == 0:
raise ValueError("At least one taskset should be provided in explorer_input!")
tasksets = explorer_input.tasksets

for i, taskset in enumerate(tasksets):
for i, taskset in enumerate(explorer_input.tasksets):
if self.mode != "train" and not taskset.path:
raise ValueError(
"`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
@@ -914,6 +913,10 @@ def _check_explorer_input(self) -> None:
set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens)

def _check_trainer_input(self) -> None:
if self.mode in {"explore", "bench", "serve"}:
# no need to check trainer_input in explore/bench/serve mode
return

trainer_input = self.buffer.trainer_input
experience_buffer = trainer_input.experience_buffer

@@ -973,7 +976,7 @@ def _default_storage_path(self, storage_type: StorageType, name: str) -> str:
def _check_data_processor(self) -> None:
# check input/output buffers in pipelines
experience_pipeline = self.data_processor.experience_pipeline
if experience_pipeline is not None:
if experience_pipeline is not None and self.mode in {"explore", "both", "serve"}:
if experience_pipeline.save_input and experience_pipeline.input_save_path is None:
experience_pipeline.input_save_path = os.path.join(
self.buffer.cache_dir, "explorer_output.jsonl" # type: ignore[arg-type]
@@ -983,7 +986,7 @@ def _check_data_processor(self) -> None:
)

task_pipeline = self.data_processor.task_pipeline
if task_pipeline is not None:
if task_pipeline is not None and self.mode in {"explore", "both"}:
if task_pipeline.output is None:
if self.mode != "train":
task_pipeline.output = self.buffer.explorer_input.tasksets[0]
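
A sketch of the mode-gating pattern the two validators now share; the mode sets match the hunks above, while the Config class here is a trimmed stand-in for the real one:

class Config:
    def __init__(self, mode: str = "both"):
        self.mode = mode

    def _check_explorer_input(self) -> None:
        if self.mode in {"train", "bench", "serve"}:
            return  # these modes never read explorer_input
        ...  # validate tasksets, paths and rollout args here

    def _check_trainer_input(self) -> None:
        if self.mode in {"explore", "bench", "serve"}:
            return  # these modes never read trainer_input
        ...  # validate the experience buffer here

Config(mode="bench")._check_explorer_input()  # returns immediately, nothing to validate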
3 changes: 1 addition & 2 deletions trinity/common/verl_config.py
@@ -85,6 +85,7 @@ class FSDPConfig:
wrap_policy: WrapPolicy = field(default_factory=WrapPolicy)
fsdp_size: int = -1
forward_prefetch: bool = False
model_dtype: Optional[str] = None


@dataclass
@@ -163,8 +164,6 @@ class Actor:
clip_ratio_high: Optional[float] = None
entropy_coeff: float = 0.001
use_kl_loss: bool = False
kl_loss_coef: float = 0.0
kl_loss_type: str = "low_var_kl"


@dataclass
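
With kl_loss_coef and kl_loss_type dropped from Actor, the notable addition in this file is the optional model_dtype field. A sketch of how such a field can be declared and resolved, assuming the same "explicit value wins, otherwise a per-role default" rule used in fsdp_workers.py below (the trimmed dataclass is a stand-in):

from dataclasses import dataclass
from typing import Optional

@dataclass
class FSDPConfig:
    fsdp_size: int = -1
    forward_prefetch: bool = False
    model_dtype: Optional[str] = None  # None -> defer to a per-role default

def resolve_model_dtype(cfg: FSDPConfig, is_actor: bool) -> str:
    # explicit setting wins; otherwise fp32 for the actor, bf16 for the reference model
    return cfg.model_dtype or ("fp32" if is_actor else "bf16")

print(resolve_model_dtype(FSDPConfig(), is_actor=True))                    # fp32
print(resolve_model_dtype(FSDPConfig(model_dtype="bf16"), is_actor=True))  # bf16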
9 changes: 7 additions & 2 deletions trinity/explorer/explorer.py
@@ -52,7 +52,9 @@ def __init__(self, config: Config):
self.models, self.auxiliary_models = create_inference_models(config)
self.experience_pipeline = self._init_experience_pipeline()
self.taskset = (
TasksetScheduler(explorer_state, config) if self.config.mode != "serve" else None
TasksetScheduler(explorer_state, config)
if self.config.mode not in {"bench", "serve"}
else None
)
self.scheduler = None
self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
@@ -151,7 +153,8 @@ async def prepare(self) -> None:
"""Preparation before running."""
try:
# prepare experience pipeline
await self.experience_pipeline.prepare.remote()
if self.experience_pipeline:
await self.experience_pipeline.prepare.remote()
self.logger.info("Experience pipeline is ready.")
# make sure all rollout models are ready
run_api_ref = [model.run_api_server.remote() for model in self.models]
@@ -406,6 +409,8 @@ async def is_alive(self) -> bool:

def _init_experience_pipeline(self) -> ray.actor.ActorHandle:
"""Init experience pipeline for the explorer."""
if self.config.mode == "bench":
return None
node_id = ray.get_runtime_context().get_node_id()
return (
ray.remote(ExperiencePipeline)
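
The changes above follow one idea: in bench (and serve) mode some components are simply not created, so later code must tolerate None. A small sketch of that guard, with plain objects standing in for the Ray actor handles:

from typing import Optional

class ExperiencePipeline:
    def prepare(self) -> None:
        print("experience pipeline ready")

def init_experience_pipeline(mode: str) -> Optional[ExperiencePipeline]:
    if mode == "bench":
        return None  # bench mode never post-processes experiences
    return ExperiencePipeline()

pipeline = init_experience_pipeline("bench")
if pipeline:               # mirrors the `if self.experience_pipeline:` guard in prepare()
    pipeline.prepare()     # skipped here, so bench runs never block on the pipeline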
5 changes: 4 additions & 1 deletion trinity/manager/synchronizer.py
@@ -44,7 +44,10 @@ def __init__(self, config: Config, module_ref: ray.actor.ActorHandle):
self._modules = {module_ref}
self._modules_lock = asyncio.Lock()
asyncio.create_task(self._check_modules())
if self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT:
if (
self.config.mode != "bench"
and self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT
):
asyncio.create_task(self._find_latest_state_dict())

async def add_module(self, module_ref: ray.actor.ActorHandle) -> None:
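
A compact sketch of the new guard, assuming only that SyncMethod.CHECKPOINT exists as in the diff; the polling coroutine is reduced to a no-op:

import asyncio
from enum import Enum

class SyncMethod(Enum):
    CHECKPOINT = "checkpoint"

async def _find_latest_state_dict() -> None:
    pass  # placeholder for the checkpoint polling loop

async def start(mode: str, sync_method: SyncMethod) -> None:
    if mode != "bench" and sync_method == SyncMethod.CHECKPOINT:
        asyncio.create_task(_find_latest_state_dict())  # only non-bench runs poll for checkpoints

asyncio.run(start("bench", SyncMethod.CHECKPOINT))  # bench: no polling task is created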
23 changes: 12 additions & 11 deletions trinity/trainer/verl/fsdp_workers.py
@@ -168,7 +168,10 @@ def __init__(self, config: DictConfig, role: str):
self.config.actor.ppo_micro_batch_size
)

if self.config.actor.ppo_micro_batch_size_per_gpu is not None:
if (
not self.config.actor.use_dynamic_bsz
and self.config.actor.ppo_micro_batch_size_per_gpu is not None
):
assert (
self.config.actor.ppo_mini_batch_size
% self.config.actor.ppo_micro_batch_size_per_gpu
@@ -181,7 +184,11 @@ def __init__(self, config: DictConfig, role: str):
), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"

# normalize ref config
if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
if (
self._is_ref
and not self.config.ref.log_prob_use_dynamic_bsz
and self.config.ref.log_prob_micro_batch_size is not None
):
self.config.ref.log_prob_micro_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
@@ -246,7 +253,7 @@ def _build_model_optimizer( # noqa: C901
else:
self.tokenizer.chat_template = self.config.model.custom_chat_template

torch_dtype = fsdp_config.get("model_dtype", None)
torch_dtype = fsdp_config.model_dtype
if torch_dtype is None:
torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
else:
@@ -326,9 +333,6 @@ def _build_model_optimizer( # noqa: C901
fused_kernels_backend=fused_kernels_backend,
)

# some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
actor_module.to(torch_dtype)

if enable_gradient_checkpointing:
actor_module.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
@@ -971,7 +975,7 @@ def __init__(self, config):
self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size

if self.config.ppo_micro_batch_size_per_gpu is not None:
if not self.config.use_dynamic_bsz and self.config.ppo_micro_batch_size_per_gpu is not None:
assert (
self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
@@ -1020,7 +1024,7 @@ def _build_critic_model_optimizer(self, config): # noqa: C901
if self.rank == 0:
print(f"Critic overriding config {override_config_kwargs}")

torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
torch_dtype = self.config.model.fsdp_config.model_dtype or "fp32"
torch_dtype = PrecisionType.to_dtype(torch_dtype)

from transformers import AutoConfig
@@ -1060,9 +1064,6 @@ def _build_critic_model_optimizer(self, config): # noqa: C901
ulysses_sp_size=self.ulysses_sequence_parallel_size,
)

# some parameters may not in torch_dtype
critic_module.to(torch_dtype)

if config.model.get("enable_gradient_checkpointing", False):
critic_module.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
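
The common thread in this file is that the micro-batch divisibility checks are skipped when dynamic batch sizing is enabled, since micro-batches are then packed by token budget rather than a fixed per-GPU size. A sketch of that gate as a standalone function (hypothetical; the real checks live inline in the worker __init__):

from typing import Optional

def check_micro_batch(ppo_mini_batch_size: int,
                      ppo_micro_batch_size_per_gpu: Optional[int],
                      use_dynamic_bsz: bool) -> None:
    if use_dynamic_bsz or ppo_micro_batch_size_per_gpu is None:
        return  # dynamic batching makes the fixed-size divisibility check meaningless
    assert ppo_mini_batch_size % ppo_micro_batch_size_per_gpu == 0, (
        f"ppo_mini_batch_size {ppo_mini_batch_size} should be divisible by "
        f"ppo_micro_batch_size_per_gpu {ppo_micro_batch_size_per_gpu}"
    )
    assert ppo_mini_batch_size // ppo_micro_batch_size_per_gpu > 0

check_micro_batch(64, 8, use_dynamic_bsz=False)  # passes: 64 % 8 == 0
check_micro_batch(64, 8, use_dynamic_bsz=True)   # skipped entirely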
72 changes: 39 additions & 33 deletions trinity/trainer/verl/megatron_checkpoint_manager.py
@@ -233,44 +233,50 @@ def save_checkpoint( # noqa: C901
json.dump(transformer_config_dict, f, indent=2)

if self.should_save_hf_model or save_as_hf:
# wait for everyone to dump to local
state_dict = self.weight_saver(
self.model,
self.hf_config,
dtype=self.param_dtype,
is_value_model=self.is_value_model,
tie_word_embeddings=self.share_embeddings_and_output_weights,
)
try:
# wait for everyone to dump to local
state_dict = self.weight_saver(
self.model,
self.hf_config,
dtype=self.param_dtype,
is_value_model=self.is_value_model,
tie_word_embeddings=self.share_embeddings_and_output_weights,
)

torch.distributed.barrier()
if self.rank == 0:
# TODO: async save or use mbridge to save hf model
hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
import warnings
torch.distributed.barrier()
if self.rank == 0:
# TODO: async save or use mbridge to save hf model
hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
import warnings

from accelerate import init_empty_weights
from accelerate import init_empty_weights

with init_empty_weights(), warnings.catch_warnings():
warnings.simplefilter("ignore")
if "mistral7b-rm" in self.config.model.path:
from transformers import MistralForSequenceClassification
with init_empty_weights(), warnings.catch_warnings():
warnings.simplefilter("ignore")
if "mistral7b-rm" in self.config.model.path:
from transformers import MistralForSequenceClassification

model = MistralForSequenceClassification.from_pretrained(
self.config.model.path
) # use score head instead of lm_head
state_dict["score.weight"] = state_dict["score.weight"]
else:
from transformers import AutoModelForCausalLM
model = MistralForSequenceClassification.from_pretrained(
self.config.model.path
) # use score head instead of lm_head
state_dict["score.weight"] = state_dict["score.weight"]
else:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
self.config.model.path, torch_dtype="auto"
)
model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
log_with_rank(
f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
rank=self.rank,
logger=logger,
log_only_rank_0=True,
model = AutoModelForCausalLM.from_pretrained(
self.config.model.path, torch_dtype="auto"
)
model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
log_with_rank(
f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}",
rank=self.rank,
logger=logger,
log_only_rank_0=True,
)
except Exception:
logger.error(
f"Failed to save Huggingface model to {local_path}, you can try to set `use_mbridge=true` to save it.",
exc_info=True,
)

ray.get(
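
The whole HF-export branch is now wrapped so a failure there no longer aborts the surrounding checkpoint save. A minimal sketch of that containment pattern (save_hf_model is a hypothetical stand-in; the error message mirrors the one added in the diff):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def save_hf_model(path: str) -> None:
    raise RuntimeError("simulated export failure")  # stand-in for the weight_saver + save_pretrained path

def save_checkpoint(local_path: str, save_as_hf: bool) -> None:
    # ... megatron/distributed checkpoint files are written first ...
    if save_as_hf:
        try:
            save_hf_model(local_path)
        except Exception:
            logger.error(
                f"Failed to save Huggingface model to {local_path}, "
                "you can try to set `use_mbridge=true` to save it.",
                exc_info=True,
            )
    # bookkeeping below still runs even if the HF export failed

save_checkpoint("/tmp/ckpt", save_as_hf=True)  # logs the error and returns normally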