@@ -344,40 +344,6 @@ def init_workers(self): # noqa: C901
344344 )
345345 self .resource_pool_to_cls [resource_pool ][str (Role .RefPolicy )] = ref_policy_cls
346346
347- # create a reward model if reward_fn is None
348- # for legacy discriminative reward model, we create a reward model worker here
349- # for reward loop discriminative reward model, we create a reward loop manager here
350- if not self .use_reward_loop :
351- # legacy reward model only handle reward-model based scenario
352- if self .use_rm :
353- # we create a RM here
354- resource_pool = self .resource_pool_manager .get_resource_pool (Role .RewardModel )
355- rm_cls = RayClassWithInitArgs (
356- self .role_worker_mapping [Role .RewardModel ], config = self .config .reward_model
357- )
358- self .resource_pool_to_cls [resource_pool ][str (Role .RewardModel )] = rm_cls
359- else :
360- # reward loop handle hybrid reward scenario (rule, disrm, genrm, ...)
361- # Note: mode is always "async" since sync mode is deprecated
362- can_reward_loop_parallelize = (
363- not self .use_rm or self .config .reward_model .enable_resource_pool
364- )
365- # judge if we can asynchronously parallelize reward model with actor rollout
366- # two condition that we can parallelize reward model with actor rollout:
367- # 1. reward model is not enabled (rule-based reward can parallelize)
368- # 2. reward model is enabled but extra resource pool is enabled
369- # If we cannot parallelize, we should enable synchronous mode here, and launch a reward loop manager here
370- # else for parallelize mode, we launch a reward worker for each rollout worker (in agent loop, not here)
371- if not can_reward_loop_parallelize :
372- from verl .experimental .reward_loop import RewardLoopManager
373-
374- self .config .reward_model .n_gpus_per_node = self .config .trainer .n_gpus_per_node
375- resource_pool = self .resource_pool_manager .get_resource_pool (Role .RewardModel )
376- self .reward_loop_manager = RewardLoopManager (
377- config = self .config ,
378- rm_resource_pool = resource_pool ,
379- )
380-
381347 # initialize WorkerGroup
382348 # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
383349 # you should not use `create_colocated_worker_cls`.
@@ -439,12 +405,6 @@ def init_workers(self): # noqa: C901
439405 assert str (Role .ActorRolloutRef ) in all_wg , f"{ all_wg .keys ()= } "
440406 self .ref_policy_wg = all_wg [str (Role .ActorRolloutRef )]
441407
442- self .rm_wg = None
443- # initalization of rm_wg will be deprecated in the future
444- if self .use_rm and not self .use_reward_loop :
445- self .rm_wg = all_wg [str (Role .RewardModel )]
446- self .rm_wg .init_model ()
447-
448408 # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
449409 self .actor_rollout_wg = all_wg [str (actor_role )]
450410 self .actor_rollout_wg .init_model ()
@@ -515,13 +475,14 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict: # noqa C901
515475 "bypass_mode" , False
516476 )
517477 if bypass_recomputing_logprobs : # Use `rollout_log_probs`
518- from verl .trainer .ppo .rollout_corr_helper import apply_bypass_mode
478+ if "rollout_log_probs" in batch .batch :
479+ from verl .trainer .ppo .rollout_corr_helper import apply_bypass_mode
519480
520- apply_bypass_mode (
521- batch = batch ,
522- rollout_corr_config = rollout_corr_config ,
523- policy_loss_config = self .config .actor_rollout_ref .actor .policy_loss ,
524- )
481+ apply_bypass_mode (
482+ batch = batch ,
483+ rollout_corr_config = rollout_corr_config ,
484+ policy_loss_config = self .config .actor_rollout_ref .actor .policy_loss ,
485+ )
525486 else : # Recompute old_log_probs TODO: to be check
526487 if (batch .meta_info ["model_versions" ] != self .global_steps - 1 ).any ():
527488 self .logger .warning (
@@ -551,8 +512,6 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict: # noqa C901
551512
552513 metrics .update (calculate_debug_metrics (batch ))
553514
554- assert "old_log_probs" in batch .batch , f'"old_log_prob" not in { batch .batch .keys ()= } '
555-
556515 if self .algorithm .use_reference : # ref_logprob may not be used
557516 # compute reference log_prob
558517 with marked_timer (str (Role .RefPolicy ), timing_raw , color = "olive" ):
0 commit comments