Commit e99d23a

apply reviews and fix trainer_test

1 parent df9cc27

9 files changed: +15, -96 lines changed

.github/workflows/docker/docker-compose.yaml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 services:
   trinity-node-1:
-    image: trinity-rft-unittest:20251225
+    image: trinity-rft-unittest:20260115
     pull_policy: never
     command: bash -c "source /opt/venv/bin/activate && uv pip install -e .[dev] && ray start --head --dashboard-host 0.0.0.0 --include-dashboard true --block"
     environment:
@@ -30,7 +30,7 @@ services:
           capabilities: [gpu]
 
   trinity-node-2:
-    image: trinity-rft-unittest:20251225
+    image: trinity-rft-unittest:20260115
     pull_policy: never
     command: bash -c "source /opt/venv/bin/activate && uv pip install -e .[dev] && ray start --address=trinity-node-1:6379 --block"
     environment:

trinity/common/models/vllm_worker.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 class WorkerExtension:
     def apply_patches(self):
         """Apply necessary patches to vLLM."""
-        from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader
+        from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
 
         patch_vllm_moe_model_weight_loader(self.model_runner.model)
         patch_vllm_prompt_logprobs(self.model_runner)
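The import path of the MoE weight-loader patch moved between verl releases. If surrounding code ever needed to run against both layouts, a guarded import along these lines would cover the rename (a sketch, not part of this commit):

    # Sketch: tolerate both verl layouts for the MoE weight-loader patch.
    # The new path (verl.utils.vllm.patch) is what this commit switches to;
    # the old path (verl.utils.vllm_utils) is kept only as a fallback.
    try:
        from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
    except ImportError:
        from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader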

trinity/trainer/verl/fsdp_checkpoint_manager.py

Lines changed: 1 addition & 0 deletions
@@ -419,6 +419,7 @@ def save_state_dict(
             self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg
         ):
             self._save_model(local_path, global_step)
+            self._save_tokenizer(local_path, global_step)
         ray.get(
             self.checkpoint_monitor.register_thread_count.remote(
                 global_step, state_dict_thread_count=1
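The added call writes the tokenizer next to the sharded model state when a checkpoint is saved. The body of _save_tokenizer is not shown in this diff; a minimal sketch, assuming a HuggingFace-style tokenizer held on self.tokenizer, could look like:

    # Hypothetical sketch of _save_tokenizer; the real implementation is not part of this diff.
    def _save_tokenizer(self, local_path: str, global_step: int) -> None:
        # global_step only mirrors the call site above and is unused here.
        # save_pretrained writes tokenizer_config.json, the vocab files, etc.,
        # so the checkpoint directory is loadable for inference on its own.
        self.tokenizer.save_pretrained(local_path)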

trinity/trainer/verl/fsdp_workers.py

Lines changed: 1 addition & 21 deletions
@@ -53,7 +53,6 @@
     get_device_name,
     get_nccl_backend,
     get_torch_device,
-    set_expandable_segments,
 )
 from verl.utils.flops_counter import FlopsCounter
 from verl.utils.fs import copy_to_local
@@ -75,7 +74,6 @@
 )
 from verl.utils.import_utils import import_external_libs
 from verl.utils.logger import log_with_rank
-from verl.utils.memory_utils import aggressive_empty_cache
 from verl.utils.profiler import (
     DistProfiler,
     DistProfilerExtension,
@@ -640,24 +638,6 @@ def _build_model_optimizer(  # noqa: C901
 
         return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
 
-    async def trainer_mode(self):  # TODO: check this
-        """Context switch hybridengine to trainer mode."""
-        # if self.config.rollout.free_cache_engine:
-        #     log_gpu_memory_usage("Before rollout offload", logger=logger)
-        #     await self.rollout.release()
-        #     log_gpu_memory_usage("After rollout offload", logger=logger)
-
-        self.actor_module_fsdp.train()
-
-        # add empty cache after each compute
-        aggressive_empty_cache(force_sync=True)
-
-        set_expandable_segments(True)
-
-        # restore random states
-        self.gen_random_states = get_torch_device().get_rng_state()
-        get_torch_device().set_rng_state(self.torch_random_states)
-
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
         from trinity.trainer.verl.dp_actor import DataParallelPPOActor
@@ -1604,7 +1584,7 @@ def update_critic(self, data: DataProto):
         )
 
         lr = self.critic_lr_scheduler.get_last_lr()[0]
-        metrics["critic/lr"] = lr
+        metrics["critic/lr"] = lr.item() if torch.is_tensor(lr) else lr
         self.critic_lr_scheduler.step()
 
         output = DataProto(batch=None, meta_info={"metrics": metrics})
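The critic/lr change guards against get_last_lr() handing back a 0-d tensor, which some optimizer/scheduler combinations can do and which most metric sinks refuse to serialize. A self-contained sketch of the same guard:

    import torch

    def lr_to_scalar(lr):
        # get_last_lr() usually returns floats, but can return a 0-d tensor
        # depending on how the optimizer stores its learning rate; .item()
        # converts it to a plain Python float that any logger accepts.
        return lr.item() if torch.is_tensor(lr) else lr

    assert lr_to_scalar(torch.tensor(0.5)) == lr_to_scalar(0.5) == 0.5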

trinity/trainer/verl/megatron_actor.py

Lines changed: 1 addition & 1 deletion
@@ -363,7 +363,7 @@ def logits_processor(logits, label, label_mask):
             logits.div_(temperature)
             ret = {}
             if calculate_entropy:
-                # The veRL fix consumes more GPU memory than our implementation
+                # The veRL fix consumes more GPU memory than our implementation
                 # (.clone() v.s. monkey patch on megatron function);
                 # therefore, we have temporarily commented out the veRL fix.
                 # logits_bak = logits.clone()

trinity/trainer/verl/megatron_checkpoint_manager.py

Lines changed: 1 addition & 0 deletions
@@ -396,6 +396,7 @@ def save_state_dict(  # noqa: C901
 
         local_path = local_mkdir_safe(local_path)
         self._save_state_dict(local_path, global_step)
+        self._save_tokenizer(local_path, global_step)
         ray.get(
             self.checkpoint_monitor.register_thread_count.remote(
                 global_step, state_dict_thread_count=1

trinity/trainer/verl/megatron_workers.py

Lines changed: 0 additions & 22 deletions
@@ -745,28 +745,6 @@ def upload_state_dict(self, trainer_step: int):
     def set_algorithm(self, algo_config: AlgorithmConfig):
         self.actor.set_algorithm(algo_config)
 
-    async def trainer_mode(self):
-        """Context switch hybridengine to trainer mode."""
-        # if self.config.rollout.free_cache_engine:
-        #     log_gpu_memory_usage("Before rollout offload", logger=logger)
-        #     await self.rollout.release()
-        #     log_gpu_memory_usage("After rollout offload", logger=logger)
-
-        for model in self.actor.actor_module:
-            model.train()
-        # add empty cache after each compute
-        aggressive_empty_cache(force_sync=True)
-
-        # FIXME(@wuxibin): megatron+sglang failed with `expandable_segments:True` in ci,
-        # can't reproduce it in dev environment, temporary disable it.
-        # https://github.com/volcengine/verl/actions/runs/17382936845/job/49344264323?pr=3285
-        if os.environ.get("MEGATRON_CI_DISABLE_EXPANDABLE_SEGMENTS", "0") == "0":
-            set_expandable_segments(True)
-
-        # restore random states
-        self.gen_random_states = get_torch_device().get_rng_state()
-        get_torch_device().set_rng_state(self.torch_random_states)
-
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @GPUMemoryLogger(role="update_actor", logger=logger)
     @DistProfiler.annotate(color="red", role="actor_update")

trinity/trainer/verl/utils.py

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ def to_data_proto(
         )
     else:
         raise ValueError("Custom fields are not consistent across experiences.")
-    meta_info = {"model_versions": np.array([exp.info["model_version"] for exp in experiences])}
+    meta_info = {"model_versions": np.array([exp.info.get("model_version", 0) for exp in experiences])}
     return DataProto.from_single_dict(batch_dict, meta_info=meta_info)
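With the change above, experiences that carry no model_version entry now default to 0 instead of raising a KeyError. A tiny sketch of the resulting behaviour, assuming exp.info is a plain dict:

    import numpy as np

    # An experience recorded without a model_version simply maps to 0.
    infos = [{"model_version": 3}, {}]
    model_versions = np.array([info.get("model_version", 0) for info in infos])
    print(model_versions)  # [3 0]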

trinity/trainer/verl_trainer.py

Lines changed: 7 additions & 48 deletions
@@ -344,40 +344,6 @@ def init_workers(self):  # noqa: C901
             )
             self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls
 
-        # create a reward model if reward_fn is None
-        # for legacy discriminative reward model, we create a reward model worker here
-        # for reward loop discriminative reward model, we create a reward loop manager here
-        if not self.use_reward_loop:
-            # legacy reward model only handle reward-model based scenario
-            if self.use_rm:
-                # we create a RM here
-                resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
-                rm_cls = RayClassWithInitArgs(
-                    self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model
-                )
-                self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls
-        else:
-            # reward loop handle hybrid reward scenario (rule, disrm, genrm, ...)
-            # Note: mode is always "async" since sync mode is deprecated
-            can_reward_loop_parallelize = (
-                not self.use_rm or self.config.reward_model.enable_resource_pool
-            )
-            # judge if we can asynchronously parallelize reward model with actor rollout
-            # two condition that we can parallelize reward model with actor rollout:
-            # 1. reward model is not enabled (rule-based reward can parallelize)
-            # 2. reward model is enabled but extra resource pool is enabled
-            # If we cannot parallelize, we should enable synchronous mode here, and launch a reward loop manager here
-            # else for parallelize mode, we launch a reward worker for each rollout worker (in agent loop, not here)
-            if not can_reward_loop_parallelize:
-                from verl.experimental.reward_loop import RewardLoopManager
-
-                self.config.reward_model.n_gpus_per_node = self.config.trainer.n_gpus_per_node
-                resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
-                self.reward_loop_manager = RewardLoopManager(
-                    config=self.config,
-                    rm_resource_pool=resource_pool,
-                )
-
         # initialize WorkerGroup
         # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
         # you should not use `create_colocated_worker_cls`.
@@ -439,12 +405,6 @@ def init_workers(self):  # noqa: C901
         assert str(Role.ActorRolloutRef) in all_wg, f"{all_wg.keys()=}"
         self.ref_policy_wg = all_wg[str(Role.ActorRolloutRef)]
 
-        self.rm_wg = None
-        # initalization of rm_wg will be deprecated in the future
-        if self.use_rm and not self.use_reward_loop:
-            self.rm_wg = all_wg[str(Role.RewardModel)]
-            self.rm_wg.init_model()
-
         # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
         self.actor_rollout_wg = all_wg[str(actor_role)]
         self.actor_rollout_wg.init_model()
@@ -515,13 +475,14 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict:  # noqa C901
             "bypass_mode", False
         )
         if bypass_recomputing_logprobs:  # Use `rollout_log_probs`
-            from verl.trainer.ppo.rollout_corr_helper import apply_bypass_mode
+            if "rollout_log_probs" in batch.batch:
+                from verl.trainer.ppo.rollout_corr_helper import apply_bypass_mode
 
-            apply_bypass_mode(
-                batch=batch,
-                rollout_corr_config=rollout_corr_config,
-                policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss,
-            )
+                apply_bypass_mode(
+                    batch=batch,
+                    rollout_corr_config=rollout_corr_config,
+                    policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss,
+                )
         else:  # Recompute old_log_probs  TODO: to be check
             if (batch.meta_info["model_versions"] != self.global_steps - 1).any():
                 self.logger.warning(
@@ -551,8 +512,6 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict:  # noqa C901
 
         metrics.update(calculate_debug_metrics(batch))
 
-        assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}'
-
        if self.algorithm.use_reference:  # ref_logprob may not be used
            # compute reference log_prob
            with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
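The bypass branch now imports and applies apply_bypass_mode only when the batch actually contains rollout_log_probs, and the unconditional old_log_probs assertion is gone, so batches without rollout log-probs fall through untouched. A minimal sketch of the guard pattern, using illustrative stand-in names rather than the real DataProto/verl signatures:

    # Illustrative only: `batch` is a plain dict and `apply_fn` a stand-in
    # for the bypass-mode helper; neither mirrors the real verl API.
    def maybe_apply_bypass(batch: dict, apply_fn) -> None:
        # Only correct with rollout log-probs when the rollout produced them;
        # otherwise leave the batch untouched and let the trainer proceed.
        if "rollout_log_probs" in batch:
            apply_fn(batch)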
