40 changes: 31 additions & 9 deletions apps/grpo/main.py
@@ -20,7 +20,7 @@
get_dcp_whole_state_dict_key,
get_param_prefix,
)
from forge.actors.policy import Policy
from forge.actors.policy import Policy as Generator
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import RLTrainer
@@ -162,6 +162,24 @@ def simple_grpo_loss(
return loss


class Policy(Generator):
"""Policy model to collect data from"""


class Judge(Generator):
"""Judge model for computing rewards"""


async def evaluate(judge, prompt, response):
task = f"""As an expert mathematician, and given the following question:
{prompt}
Is the following response correct?
{response}
Answer with a Positive or Negative."""
verdicts = await judge.generate.route(task)
reward = 1.0 if "Positive" in verdicts[0].text else 0.0
return reward

@dataclass
class RewardActor(ForgeActor):
"""Reward actor that uses a list of scoring functions."""
@@ -330,14 +348,16 @@ async def main(cfg: DictConfig):
(
dataloader,
policy,
judge,
trainer,
replay_buffer,
compute_advantages,
ref_model,
reward_actor,
#reward_actor,
) = await asyncio.gather(
DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
Policy.options(**cfg.services.policy).as_service(**cfg.policy),
Judge.options(**cfg.services.judge).as_service(**cfg.judge),
RLTrainer.options(**cfg.actors.trainer).as_actor(
**cfg.trainer, loss=simple_grpo_loss
),
@@ -346,9 +366,9 @@
),
ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
RewardActor.options(**cfg.services.reward_actor).as_service(
reward_functions=[MathReward(), ThinkingReward()]
),
#RewardActor.options(**cfg.services.reward_actor).as_service(
# reward_functions=[MathReward(), ThinkingReward()]
#),
)

print("All services initialized successfully!")
@@ -403,9 +423,10 @@ async def continuous_rollouts():
episode.response = response.text
input_ids[i, :max_req_tokens] = episode.request_tensor
input_ids[i, max_req_tokens:] = episode.response_tensor
episode.reward = await reward_actor.evaluate_response.route(
prompt=prompt, response=response.text, target=target
)
#episode.reward = await reward_actor.evaluate_response.route(
# prompt=prompt, response=response.text, target=target
#)
episode.reward = await evaluate(judge, prompt, response.text)

t.step("reward_evaluation")
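
A note on the new reward line in this hunk: the judge model configured below is Qwen/Qwen3-4B-Thinking-2507, whose completions typically include a reasoning trace before the final verdict, while evaluate() scans the entire completion for the substring "Positive". A slightly more defensive parse, shown here as a hypothetical variant rather than part of the diff, would only inspect the tail of the output:

def parse_verdict(text: str) -> float:
    # Hypothetical variant of the verdict parsing (not in this PR): look only
    # at the final non-empty line of the judge's completion, so "Positive"
    # appearing inside a long reasoning trace does not flip the reward.
    lines = [line for line in text.strip().splitlines() if line.strip()]
    return 1.0 if lines and "Positive" in lines[-1] else 0.0
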

@@ -501,11 +522,12 @@ async def continuous_training():
await asyncio.gather(
DatasetActor.shutdown(dataloader),
policy.shutdown(),
judge.shutdown(),
RLTrainer.shutdown(trainer),
ReplayBuffer.shutdown(replay_buffer),
ComputeAdvantages.shutdown(compute_advantages),
ref_model.shutdown(),
reward_actor.shutdown(),
#reward_actor.shutdown(),
)
# TODO - add a global shutdown that implicitly shuts down all services
# and remote allocations
18 changes: 18 additions & 0 deletions apps/grpo/qwen3_1_7b.yaml
@@ -43,6 +43,20 @@ policy:
temperature: 1.0
top_p: 1.0

# Judge configuration
judge:
engine_config:
model: "Qwen/Qwen3-4B-Thinking-2507"
tensor_parallel_size: 1
pipeline_parallel_size: 1
enforce_eager: false
sampling_config:
n: 1
guided_decoding: False
max_tokens: ${max_res_tokens}
temperature: 0.0
top_p: 1.0

# Trainer configuration
trainer:
model:
@@ -118,6 +132,10 @@ services:
procs: ${policy.engine_config.tensor_parallel_size}
num_replicas: 1
with_gpus: true
judge:
procs: ${judge.engine_config.tensor_parallel_size}
num_replicas: 1
with_gpus: true
ref_model:
procs: 1
num_replicas: 1
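
The new judge entries reuse the OmegaConf-style ${...} interpolations already used elsewhere in this file (main.py takes a DictConfig): services.judge.procs follows judge.engine_config.tensor_parallel_size, and max_tokens follows max_res_tokens. Here is a rough sketch of how those references resolve, assuming plain OmegaConf loading; the max_res_tokens value is a placeholder, since its real definition lives elsewhere in qwen3_1_7b.yaml.

from omegaconf import OmegaConf

# Trimmed-down copy of the new judge config; max_res_tokens is assumed.
yaml_snippet = """
max_res_tokens: 2048
judge:
  engine_config:
    model: "Qwen/Qwen3-4B-Thinking-2507"
    tensor_parallel_size: 1
  sampling_config:
    max_tokens: ${max_res_tokens}
    temperature: 0.0
services:
  judge:
    procs: ${judge.engine_config.tensor_parallel_size}
    num_replicas: 1
    with_gpus: true
"""

cfg = OmegaConf.create(yaml_snippet)
print(cfg.services.judge.procs)              # 1, tracks tensor_parallel_size
print(cfg.judge.sampling_config.max_tokens)  # 2048, tracks max_res_tokens
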