apps/grpo/main.py (334 changes: 171 additions & 163 deletions)
@@ -19,6 +19,8 @@
from forge.data.rewards import MathReward, ThinkingReward
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM, push_state_dict
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)
@@ -67,7 +69,13 @@ class Group:
class Episode:
    """Episode container for GRPO rollouts."""

    def __init__(self, episode_id: int, prompt: str, target: str, policy_version: int):
    def __init__(
        self,
        episode_id: int,
        prompt: str,
        target: str,
        policy_version: int | None = None,
    ):
        self.episode_id = episode_id
        self.prompt = prompt
        self.target = target
@@ -87,11 +95,13 @@ def __init__(
        beta: float = 0.1,
        model_name: str = "",
        device: torch.device | None = None,
        store: MultiProcessStore | None = None,
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.beta = beta  # KL penalty coefficient
        self.model_name = model_name
        self.store = store

        # Set device
        if device is None:
@@ -189,31 +199,13 @@ async def train_step(self, batch: list[Episode]):
return {"loss": avg_loss, "groups_processed": num_groups_processed}

@endpoint
async def update_weights(self, policy_actor):
async def save_weights(self):
"""Update policy model weights with trainer's current weights."""
# Time how long it takes to update weights
start_time = time.time()

# Set model to eval mode for weight extraction
self.model.eval()

# Extract current model state dict
model_state_dict = self.model.state_dict()

# Convert tensors to CPU for transfer (if they're on GPU)
cpu_state_dict = {}
for key, tensor in model_state_dict.items():
cpu_state_dict[key] = tensor.cpu() if tensor.is_cuda else tensor

# Update the policy actor's model weights
await policy_actor.update_model_weights.choose(cpu_state_dict)

# Set model back to training mode
self.model.train()

# Log the time taken
assert self.store is not None, "Store must be provided to save weights"
await push_state_dict(self.store, self.model.state_dict(), "model_state_dict")
end_time = time.time()
self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds")
self.logger.info(f"Saving weights took {end_time - start_time:.2f} seconds")


class RewardActor(ForgeActor):
@@ -347,158 +339,174 @@ async def main():
    group_size = 1
    model = "Qwen/Qwen3-1.7B"

    # ---- Setup WandB Logger ---- #
    logger = get_metric_logger(
        "wandb",
        freq=1,
        project="grpo-training",
    )
    store = await MultiProcessStore.create_store()

    # ---- Setup services ---- #
    (
        dataloader,
        policy,
        trainer,
        replay_buffer,
        compute_advantages,
        ref_model,
        reward_actor,
    ) = await asyncio.gather(
        spawn_service(
            ServiceConfig(procs_per_replica=1, num_replicas=1),
            DatasetActor,
            path="openai/gsm8k",
            config_name="main",
            split="train",
            streaming=True,
        ),
        spawn_service(
            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
            Policy,
            config=PolicyConfig(
                worker_params=WorkerConfig(model=model),
                sampling_params=SamplingOverrides(
                    num_samples=group_size, max_tokens=16
                ),
            ),
        ),
        spawn_service(
            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
            Trainer,
            learning_rate=1e-5,
            beta=0.1,
            model_name=model,
        ),
        spawn_service(
            ServiceConfig(procs_per_replica=1, num_replicas=1),
            ReplayBuffer,
            batch_size=4,
            max_policy_age=1,
        ),
        spawn_service(
            ServiceConfig(procs_per_replica=1, num_replicas=1),
            ComputeAdvantages,
            gamma=0.99,
            lambda_=0.95,
        ),
        spawn_service(
            ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
            RefModel,
            model_name=model,
        ),
        spawn_service(
            ServiceConfig(procs_per_replica=1, num_replicas=1),
            RewardActor,
            reward_functions=[MathReward(), ThinkingReward()],
        ),
    # (
    #     dataloader,
    #     policy,
    #     trainer,
    #     replay_buffer,
    #     compute_advantages,
    #     ref_model,
    #     reward_actor,
    # ) = await asyncio.gather(
    #     spawn_service(
    #         ServiceConfig(procs_per_replica=1, num_replicas=1),
    #         DatasetActor,
    #         path="openai/gsm8k",
    #         config_name="main",
    #         split="train",
    #         streaming=True,
    #     ),
    #     spawn_service(
    #         ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
    #         Policy,
    #         config=PolicyConfig(
    #             worker_params=WorkerConfig(model=model),
    #             sampling_params=SamplingOverrides(
    #                 num_samples=group_size, max_tokens=16
    #             ),
    #         ),
    #     ),
    #     spawn_service(
    #         ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
    #         Trainer,
    #         learning_rate=1e-5,
    #         beta=0.1,
    #         model_name=model,
    #         store=store,
    #     ),
    #     spawn_service(
    #         ServiceConfig(procs_per_replica=1, num_replicas=1),
    #         ReplayBuffer,
    #         batch_size=4,
    #         max_policy_age=1,
    #     ),
    #     spawn_service(
    #         ServiceConfig(procs_per_replica=1, num_replicas=1),
    #         ComputeAdvantages,
    #         gamma=0.99,
    #         lambda_=0.95,
    #     ),
    #     spawn_service(
    #         ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
    #         RefModel,
    #         model_name=model,
    #     ),
    #     spawn_service(
    #         ServiceConfig(procs_per_replica=1, num_replicas=1),
    #         RewardActor,
    #         reward_functions=[MathReward(), ThinkingReward()],
    #     ),
    # )

    trainer = await spawn_service(
        ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
        Trainer,
        learning_rate=1e-5,
        beta=0.1,
        model_name=model,
        store=store,
    )

    print("All services initialized successfully!")

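    # Likely a one-off smoke test of the torchstore save path while the full
    # service setup and RL loops below remain commented out.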
# print("Trying to save weights to torchstore...")
await trainer.save_weights.choose()

    # ---- Core RL loops ---- #
    async def continuous_rollouts():
        rollout_count = 0
        while True:
            sample = await dataloader.__next__.choose()
            if sample is None:
                print("Dataloader is empty, exiting continuous rollout")
                return
            prompt, target = sample["question"], sample["answer"]
            version = 0  # await policy.get_current_version.choose()
            episode = Episode(
                episode_id=rollout_count,
                prompt=prompt,
                target=target,
                policy_version=version,
            )
            actions = await policy.generate.choose(prompt)
            for action in actions:
                ref_logprobs = await ref_model.forward.choose(action.token_ids)
                reward = await reward_actor.evaluate_response.choose(
                    prompt=prompt, response=action.text, target=target
                )
                episode.add_group(
                    Group(
                        response=action.text,
                        ref_logprobs=ref_logprobs,
                        reward=reward,
                    )
                )

            advantages = await compute_advantages.__call__.choose(episode.groups)
            for advantage, group in zip(advantages, episode.groups):
                group.advantage = advantage

            await replay_buffer.add.choose(episode)

            rollout_count += 1
            if rollout_count % 10 == 0:
                avg_reward = sum(group.reward for group in episode.groups) / len(
                    episode.groups
                )
                print(
                    f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
                )
                logger.log("reward/rollout", avg_reward, rollout_count)

    async def continuous_training():
        training_step = 0
        while True:
            batch = await replay_buffer.sample.choose(curr_policy_version=0)
            if batch is None:
                await asyncio.sleep(0.1)
            else:
                training_result = await trainer.train_step.choose(batch)
                training_step += 1
                if training_step % 10 == 0:
                    print(f"Completed {training_step} training steps")
                    if training_result:
                        loss_value = training_result.get("loss", 0.0)
                        print(f"Latest loss: {loss_value}")
                        logger.log("loss/training_step", loss_value, training_step)
                # await trainer.update_weights(policy)

    print("Starting GRPO training loops...")
    rollout_task = asyncio.create_task(continuous_rollouts())
    training_task = asyncio.create_task(continuous_training())

    try:
        await asyncio.gather(rollout_task, training_task)
    except KeyboardInterrupt:
        print("Training interrupted by user")
        rollout_task.cancel()
        training_task.cancel()
    finally:
        print("Shutting down...")
        await asyncio.gather(
            shutdown_service(policy),
            shutdown_service(trainer),
            shutdown_service(replay_buffer),
            shutdown_service(dataloader),
            shutdown_service(compute_advantages),
            shutdown_service(ref_model),
            shutdown_service(reward_actor),
        )
    # async def continuous_rollouts():
    #     rollout_count = 0
    #     while True:
    #         sample = await dataloader.__next__.choose()
    #         if sample is None:
    #             print("Dataloader is empty, exiting continuous rollout")
    #             return
    #         prompt, target = sample["question"], sample["answer"]
    #         actions, policy_version = await policy.generate.choose(prompt)
    #         episode = Episode(
    #             episode_id=rollout_count,
    #             prompt=prompt,
    #             target=target,
    #             policy_version=policy_version,
    #         )
    #         for action in actions:
    #             ref_logprobs = await ref_model.forward.choose(action.token_ids)
    #             reward = await reward_actor.evaluate_response.choose(
    #                 prompt=prompt, response=action.text, target=target
    #             )
    #             episode.add_group(
    #                 Group(
    #                     response=action.text,
    #                     ref_logprobs=ref_logprobs,
    #                     reward=reward,
    #                 )
    #             )

    #         advantages = await compute_advantages.__call__.choose(episode.groups)
    #         for advantage, group in zip(advantages, episode.groups):
    #             group.advantage = advantage

    #         await replay_buffer.add.choose(episode)

    #         rollout_count += 1
    #         if rollout_count % 10 == 0:
    #             avg_reward = sum(group.reward for group in episode.groups) / len(
    #                 episode.groups
    #             )
    #             print(
    #                 f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
    #             )
    #             logger.log("reward/rollout", avg_reward, rollout_count)

    # async def continuous_training():
    #     on_policy_version = 0
    #     training_step = 0
    #     while True:
    #         batch = await replay_buffer.sample.choose(
    #             curr_policy_version=on_policy_version
    #         )
    #         if batch is None:
    #             await asyncio.sleep(0.1)
    #         else:
    #             training_result = await trainer.train_step.choose(batch)
    #             training_step += 1
    #             if training_step % 10 == 0:
    #                 loss_value = training_result.get("loss", 0.0)
    #                 print(
    #                     f"Completed {training_step} training steps w/ loss: {loss_value}"
    #                 )
    #                 logger.log("loss/training_step", loss_value, training_step)
    #                 print("Updating policy weights...")
    #                 await trainer.save_weights.choose()

    # # print("Starting GRPO training loops...")
    # rollout_task = asyncio.create_task(continuous_rollouts())
    # training_task = asyncio.create_task(continuous_training())

    # try:
    #     await asyncio.gather(rollout_task, training_task)
    # except KeyboardInterrupt:
    #     print("Training interrupted by user")
    #     rollout_task.cancel()
    #     training_task.cancel()
    # finally:
    #     print("Shutting down...")
    #     await asyncio.gather(
    #         shutdown_service(policy),
    #         shutdown_service(trainer),
    #         shutdown_service(replay_buffer),
    #         shutdown_service(dataloader),
    #         shutdown_service(compute_advantages),
    #         shutdown_service(ref_model),
    #         shutdown_service(reward_actor),
    #     )


if __name__ == "__main__":