Merged
36 commits
e6b7692
first changes
pbontrager Aug 29, 2025
a95a001
core updates
pbontrager Aug 31, 2025
3ba0df6
batch update
pbontrager Sep 1, 2025
3e32264
fix typo
pbontrager Sep 2, 2025
e4723bb
Merge branch 'main' into ungroup
pbontrager Sep 2, 2025
5a17c8b
Merge branch 'main' into ungroup
pbontrager Sep 2, 2025
52028a5
missing import
pbontrager Sep 2, 2025
e2a3a68
debug merge
pbontrager Sep 2, 2025
2cf9d00
more fixes
pbontrager Sep 4, 2025
b85320c
Remove dtype warnings
joecummings Sep 4, 2025
f7626ce
Stub
joecummings Sep 4, 2025
bf31587
It runs
joecummings Sep 4, 2025
53c8c89
Add in ref
joecummings Sep 4, 2025
f494949
<Replace this line with a title. Use 1 line only, 67 chars or less>
joecummings Sep 4, 2025
a13a1ac
Pass linting?
joecummings Sep 4, 2025
833a6b6
Remove extraneous 'calculations'
joecummings Sep 4, 2025
0acbe4a
Stub out push weights
joecummings Sep 4, 2025
7d05aad
Remove tokenizer, add back in formatting
joecummings Sep 4, 2025
3c880dd
Cleanup
joecummings Sep 4, 2025
8796fa1
Working w/ weight sync
joecummings Sep 4, 2025
75447d9
stub
joecummings Sep 5, 2025
2838937
Merge remote-tracking branch 'upstream/main' into working-updates
joecummings Sep 8, 2025
3120100
Queue while updating weights
joecummings Sep 8, 2025
8f4bda1
Cleanup
joecummings Sep 10, 2025
7825255
Make sd conversion happen on push
joecummings Sep 11, 2025
b511fe3
Sum over train_step valuemesh
joecummings Sep 11, 2025
9b46a77
Merge remote-tracking branch 'upstream/main' into working-updates
joecummings Sep 11, 2025
1a6d6df
Update config
joecummings Sep 11, 2025
e31f815
Loss updates
joecummings Sep 11, 2025
55c32be
Updated rewards (just played around a bit)
joecummings Sep 11, 2025
b74a47c
Update rewards
joecummings Sep 11, 2025
14d6354
Fix last math reward test
joecummings Sep 11, 2025
8fa4451
Async by 1
joecummings Sep 11, 2025
bdd03a8
Seg fault
joecummings Sep 12, 2025
7eedc91
Make torchstore actually work!
joecummings Sep 12, 2025
4044087
Last updates
joecummings Sep 12, 2025
98 changes: 57 additions & 41 deletions apps/grpo/main.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import asyncio
import logging
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Optional
@@ -21,12 +21,11 @@
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from torch import nn
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM, push_state_dict
from transformers import AutoModelForCausalLM
from vllm.transformers_utils.tokenizer import get_tokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def compute_logprobs(
logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
@@ -121,7 +120,7 @@ def new_group(
target: Any = None,
):
episodes = []
for i in range(group_size):
for _ in range(group_size):
episodes.append(
Episode(
episode_id=str(uuid.uuid4()),
@@ -145,6 +144,8 @@ class Trainer(ForgeActor):
beta: float = 0.1
epsilon: float = 0.1
device: torch.device | None = None
store: MultiProcessStore | None = None
Contributor: This isn't the currently recommended way to use the store. You should just call it as a singleton inside of the trainer.
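For reference, a minimal sketch of what that singleton pattern could look like on the trainer side. It assumes the ts.initialize()/global-singleton API mentioned later in this review; ts.put() is an assumed put-style call rather than a confirmed torchstore function, and the helper name is made up for illustration.

# Sketch only, not the PR's implementation: the trainer stops carrying a
# `store` field and talks to the global torchstore singleton directly.
# ts.put() is an ASSUMED put-style call; ts.initialize() would run once at startup.
import torch.nn as nn
import torchstore as ts
from torchstore._state_dict_utils import DELIM  # same delimiter the diff imports


async def push_weights_via_singleton(
    model: nn.Module, version: int, state_dict_key: str = "model_state_dict"
) -> None:
    key = f"{state_dict_key}{DELIM}{version}"  # version doubles as the unique id
    await ts.put(key, model.state_dict())  # assumed API; the diff uses push_state_dict(store, ...)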

state_dict_key: str = "model_state_dict"

@endpoint
def setup(self):
@@ -208,11 +209,19 @@ async def train_step(self, batch: list[Episode]):

self.optimizer.step()

return {"loss": loss.item()}
return loss.item()

@endpoint
async def push_weights(self):
pass
async def push_weights(self, version: int):
"""Update policy model weights with trainer's current weights."""
start_time = time.time()
Contributor: Suggested change:
- start_time = time.time()
+ # TODO - issues/148 followup
+ start_time = time.time()
Just for my future reference, tagging some pieces for observability.

assert self.store is not None, "Store must be provided to save weights"
key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
await push_state_dict(self.store, self.model.state_dict(), key)
end_time = time.time()
self.logger.debug(
f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
)


@dataclass
@@ -226,6 +235,9 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
total_reward = 0.0
for reward_fn in self.reward_functions:
reward = reward_fn(prompt, response, target)
self.logger.info(
f"Response: {response} | Target: {target} | Reward: {reward}"
)
total_reward += reward
return total_reward

@@ -239,15 +251,8 @@ async def compute(self, group: Group) -> list[float]:
rewards = torch.Tensor([[e.reward for e in group.episodes]])
mean = rewards.mean(1, keepdim=True)
std = rewards.std(1, keepdim=True)

# if std is nan, return 0s. Remove this before shipping
if std.isnan().any():
advantages = torch.zeros_like(rewards)
else:
advantages = (rewards - mean) / (std + 1e-4)

x = advantages.squeeze(0).tolist()
return x
advantages = (rewards - mean) / (std + 1e-4)
return advantages.squeeze(0).tolist()


class RefModel(ForgeActor):
@@ -328,10 +333,10 @@ async def pad_token(self):

async def main():
"""Main GRPO training loop with rollout and training processes."""
group_size = 4
model = "Qwen/Qwen3-1.7B-Base"
group_size = 5
model = "Qwen/Qwen3-4B-Base"
max_req_tokens = 512
max_res_tokens = 128
max_res_tokens = 512

# ---- Setup WandB Logger ---- #
logger = get_metric_logger(
@@ -340,6 +345,8 @@ async def main():
project="grpo-training",
)

store = await MultiProcessStore.create_store()
Member Author: @LucasLLC Is this still the recommended way of doing things?

Contributor: Seems like we are using ts.initialize() now and there is a global singleton torchstore. But I will let @LucasLLC weigh in.
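A sketch of what that would mean at this call site, assuming ts.initialize() sets up the global singleton as described above (whether it must be awaited is also an assumption):

import torchstore as ts


async def setup_store() -> None:
    # Sketch only: replaces `store = await MultiProcessStore.create_store()` above;
    # the store=store kwargs passed to Policy and Trainer below would then be dropped.
    await ts.initialize()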


# ---- Setup services ---- #
(
dataloader,
@@ -368,18 +375,20 @@ async def main():
n=group_size, max_tokens=max_res_tokens
),
),
store=store,
),
spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
Trainer,
learning_rate=1e-5,
model_name=model,
store=store,
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
ReplayBuffer,
batch_size=4,
max_policy_age=1,
batch_size=8,
max_policy_age=0, # Fully on-policy for now
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
@@ -409,7 +418,13 @@ async def continuous_rollouts():
print("Dataloader is empty, exiting continuous rollout")
return
prompt, target = sample["request"], sample["target"]
version = 0 # await policy.get_current_version.choose()
responses = await policy.generate.choose(prompt)
Contributor: We'll throw away a lot of data this way for fully on-policy.

Member Author: For short responses, yeah, definitely. If you look at the WandB logs (buffer_size/rollout), you can see that we build up a buffer of about 100 episodes and then evict the majority of them back and forth during weight updates. When we start allowing much longer generations and our models are much bigger, this won't be as big of an issue.

# If weights are updating mid-rollout, response will be cancelled and service
# will return None. We currently throw away the sample.
if responses is None:
continue

version = await policy.get_version.choose()
group = Group.new_group(
group_id=rollout_count,
group_size=group_size,
@@ -421,12 +436,10 @@
target=target,
)

responses = await policy.generate.choose(prompt)

# TODO: Parallelize the following calculation
for episode, response in zip(group.episodes, responses.outputs):
episode.request_tokens = responses.prompt_token_ids
episode.response_tokens = response.token_ids
assert len(response.token_ids) <= max_res_tokens
episode.ref_logprobs = await ref_model.forward.choose(episode)
episode.reward = await reward_actor.evaluate_response.choose(
prompt=prompt, response=response.text, target=target
@@ -436,30 +449,33 @@
episode.advantage = advantage
await replay_buffer.add.choose(episode)

avg_response_len = (
sum(len(e.response_tokens) for e in group.episodes) / group_size
)
logger.log("avg_response_len/rollout", avg_response_len, rollout_count)
buffer_size = await replay_buffer._numel.choose()
logger.log("buffer_size/rollout", buffer_size, rollout_count)
avg_reward = sum(e.reward for e in group.episodes) / group_size
logger.log("avg_reward/rollout", avg_reward, rollout_count)

rollout_count += 1
if rollout_count % 10 == 0:
avg_reward = sum(e.reward for e in group.episodes) / len(group.episodes)
print(
f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
)
logger.log("reward/rollout", avg_reward, rollout_count)

async def continuous_training():
training_step = 0
policy_version = 0
while True:
batch = await replay_buffer.sample.choose(curr_policy_version=0)
batch = await replay_buffer.sample.choose(
curr_policy_version=policy_version
)
if batch is None:
await asyncio.sleep(0.1)
else:
training_result = await trainer.train_step.choose(batch)
loss = await trainer.train_step.choose(batch)
Contributor: This should also be a call().

training_step += 1
if training_step % 10 == 0:
print(f"Completed {training_step} training steps")
if training_result:
loss_value = training_result.get("loss", 0.0)
print(f"Latest loss: {loss_value}")
logger.log("loss/training_step", loss_value, training_step)
# await trainer.update_weights(policy)
logger.log("loss/training_step", loss, training_step)
await trainer.push_weights.choose(policy_version)
Contributor: This should also be a call() technically, even though choose() works since replicas=1.

Contributor: This should also be a call() technically, even though choose() works since replicas=1.

As a side note, even if we do call(), what we are doing here is all the trainers training on the same batch, right? Are the replicas just for fault tolerance? In that case, if we want different trainers to train on different batches, the trainers themselves have to pull the batches, right?

An alternative is to split the batch into microbatches and call choose() on each microbatch. After the whole batch is done, we do an all_reduce (or other form of reduce) to average the weights.
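For illustration, a rough sketch of that microbatch alternative using only the service calls that appear in this diff; the final weight-averaging step is left as a hypothetical endpoint, since nothing in this PR implements it.

import asyncio


async def train_step_microbatched(trainer, batch, num_microbatches: int = 4):
    # Split the sampled batch and let choose() route each microbatch to a replica.
    size = max(1, len(batch) // num_microbatches)
    microbatches = [batch[i:i + size] for i in range(0, len(batch), size)]
    losses = await asyncio.gather(
        *(trainer.train_step.choose(mb) for mb in microbatches)
    )
    # Hypothetical follow-up: replicas average (all_reduce) their weights, e.g.
    # await trainer.all_reduce_weights.call()  # no such endpoint exists in this PR
    return sum(losses) / len(losses)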

policy_version += 1
await policy.update_weights.choose()

print("Starting GRPO training loops...")
# TODO: Start multiple rollouts once all services support it