From 119060c3c7171f4f3728ab9f938f805cf1d05011 Mon Sep 17 00:00:00 2001
From: pbontrager
Date: Tue, 7 Oct 2025 20:08:09 +0000
Subject: [PATCH] added judge changes

---
 apps/grpo/main.py         | 40 +++++++++++++++++++++++++++++++---------
 apps/grpo/qwen3_1_7b.yaml | 18 ++++++++++++++++++
 2 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index f4c7988bb..34ff54cd2 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -20,7 +20,7 @@
     get_dcp_whole_state_dict_key,
     get_param_prefix,
 )
-from forge.actors.policy import Policy
+from forge.actors.policy import Policy as Generator
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.actors.trainer import RLTrainer
@@ -162,6 +162,24 @@ def simple_grpo_loss(
     return loss
 
 
+class Policy(Generator):
+    """Policy model to collect data from"""
+
+
+class Judge(Generator):
+    """Judge model for computing rewards"""
+
+
+async def evaluate(judge, prompt, response):
+    task = f"""As an expert mathematician, and given the following question:
+    {prompt}
+    Is the following response correct?
+    {response}
+    Answer with a Positive or Negative."""
+    verdicts = await judge.generate.route(task)
+    reward = 1.0 if "Positive" in verdicts[0].text else 0.0
+    return reward
+
 @dataclass
 class RewardActor(ForgeActor):
     """Reward actor that uses a list of scoring functions."""
@@ -330,14 +348,16 @@ async def main(cfg: DictConfig):
     (
         dataloader,
         policy,
+        judge,
         trainer,
         replay_buffer,
         compute_advantages,
         ref_model,
-        reward_actor,
+        #reward_actor,
     ) = await asyncio.gather(
         DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
         Policy.options(**cfg.services.policy).as_service(**cfg.policy),
+        Judge.options(**cfg.services.judge).as_service(**cfg.judge),
         RLTrainer.options(**cfg.actors.trainer).as_actor(
             **cfg.trainer, loss=simple_grpo_loss
         ),
@@ -346,9 +366,9 @@ async def main(cfg: DictConfig):
         ),
         ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
         ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
-        RewardActor.options(**cfg.services.reward_actor).as_service(
-            reward_functions=[MathReward(), ThinkingReward()]
-        ),
+        #RewardActor.options(**cfg.services.reward_actor).as_service(
+        #    reward_functions=[MathReward(), ThinkingReward()]
+        #),
     )
 
     print("All services initialized successfully!")
@@ -403,9 +423,10 @@ async def continuous_rollouts():
                 episode.response = response.text
                 input_ids[i, :max_req_tokens] = episode.request_tensor
                 input_ids[i, max_req_tokens:] = episode.response_tensor
-                episode.reward = await reward_actor.evaluate_response.route(
-                    prompt=prompt, response=response.text, target=target
-                )
+                #episode.reward = await reward_actor.evaluate_response.route(
+                #    prompt=prompt, response=response.text, target=target
+                #)
+                episode.reward = await evaluate(judge, prompt, response.text)
 
                 t.step("reward_evaluation")
 
@@ -501,11 +522,12 @@ async def continuous_training():
     await asyncio.gather(
         DatasetActor.shutdown(dataloader),
         policy.shutdown(),
+        judge.shutdown(),
         RLTrainer.shutdown(trainer),
         ReplayBuffer.shutdown(replay_buffer),
         ComputeAdvantages.shutdown(compute_advantages),
         ref_model.shutdown(),
-        reward_actor.shutdown(),
+        #reward_actor.shutdown(),
     )
     # TODO - add a global shutdown that implicitly shuts down all services
     # and remote allocations
diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml
index 53eec5cfb..3f8de13f7 100644
--- a/apps/grpo/qwen3_1_7b.yaml
+++ b/apps/grpo/qwen3_1_7b.yaml
@@ -43,6 +43,20 @@ policy:
     temperature: 1.0
     top_p: 1.0
 
+# Judge configuration
+judge:
+  engine_config:
+    model: "Qwen/Qwen3-4B-Thinking-2507"
+    tensor_parallel_size: 1
+    pipeline_parallel_size: 1
+    enforce_eager: false
+  sampling_config:
+    n: 1
+    guided_decoding: False
+    max_tokens: ${max_res_tokens}
+    temperature: 0.0
+    top_p: 1.0
+
 # Trainer configuration
 trainer:
   model:
@@ -118,6 +132,10 @@ services:
     procs: ${policy.engine_config.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
+  judge:
+    procs: ${judge.engine_config.tensor_parallel_size}
+    num_replicas: 1
+    with_gpus: true
  ref_model:
     procs: 1
     num_replicas: 1
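
Note (not part of the patch): a minimal sketch of the verdict parsing that the new evaluate() helper introduces, assuming only what the diff itself implies, namely that judge.generate.route(task) resolves to a list of completion-like objects exposing a .text attribute. FakeJudge and Completion below are hypothetical stand-ins for exercising the Positive/Negative check, not Forge APIs.

    import asyncio
    from dataclasses import dataclass


    @dataclass
    class Completion:
        """Hypothetical stand-in for a generator completion with a .text field."""
        text: str


    class _Generate:
        """Mimics the judge.generate.route(task) call used in the patch."""

        def __init__(self, reply: str):
            self._reply = reply

        async def route(self, task: str):
            # Return a single-element list, matching verdicts[0].text in evaluate().
            return [Completion(text=self._reply)]


    class FakeJudge:
        """Hypothetical in-memory judge used only to exercise evaluate()."""

        def __init__(self, reply: str):
            self.generate = _Generate(reply)


    async def evaluate(judge, prompt, response):
        # Copied from the patch: ask the judge for a verdict and award 1.0 only
        # if "Positive" appears in the first completion's text.
        task = f"""As an expert mathematician, and given the following question:
        {prompt}
        Is the following response correct?
        {response}
        Answer with a Positive or Negative."""
        verdicts = await judge.generate.route(task)
        reward = 1.0 if "Positive" in verdicts[0].text else 0.0
        return reward


    if __name__ == "__main__":
        assert asyncio.run(evaluate(FakeJudge("Positive"), "What is 1 + 1?", "2")) == 1.0
        assert asyncio.run(evaluate(FakeJudge("Negative"), "What is 1 + 1?", "3")) == 0.0

Because the check is a plain substring match, a judge whose returned text includes reasoning that merely mentions the word "Positive" would be scored as correct; constraining the verdict format or matching only the final line may be worth considering.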