Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,15 @@ async def main(cfg: DictConfig):
ref_model,
reward_actor,
) = await asyncio.gather(
DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset),
DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
Policy.options(**cfg.services.policy).as_service(**cfg.policy),
RLTrainer.options(**cfg.services.trainer).as_service(
RLTrainer.options(**cfg.actors.trainer).as_actor(
**cfg.trainer, loss=simple_grpo_loss
),
ReplayBuffer.options(**cfg.services.replay_buffer).as_service(
ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
**cfg.replay_buffer, collate=collate
),
ComputeAdvantages.options(**cfg.services.compute_advantages).as_service(),
ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
RewardActor.options(**cfg.services.reward_actor).as_service(
reward_functions=[MathReward(), ThinkingReward()]
Expand All @@ -283,9 +283,9 @@ async def main(cfg: DictConfig):
# ---- Core RL loops ---- #
async def continuous_rollouts():
rollout_count = 0
pad_id = await dataloader.pad_token.route()
pad_id = await dataloader.pad_token.call_one()
while True:
sample = await dataloader.sample.route()
sample = await dataloader.sample.call_one()
if sample is None:
print("Dataloader is empty, exiting continuous rollout")
return
Expand Down Expand Up @@ -332,17 +332,17 @@ async def continuous_rollouts():
del ref_logits, ref_logprobs, input_ids

# Calculate advantages and add to replay buffer
advantages = await compute_advantages.compute.route(group)
advantages = await compute_advantages.compute.call_one(group)
for episode, advantage in zip(group.episodes, advantages):
episode.advantage = advantage
await replay_buffer.add.route(episode)
await replay_buffer.add.call_one(episode)

# Log metrics
avg_response_len = (
sum(len(e.response_tokens) for e in group.episodes) / group_size
)
mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count)
buffer_size = await replay_buffer._numel.route()
buffer_size = await replay_buffer._numel.call_one()
mlogger.log("buffer_size/rollout", buffer_size, rollout_count)
avg_reward = sum(e.reward for e in group.episodes) / group_size
mlogger.log("avg_reward/rollout", avg_reward, rollout_count)
Expand All @@ -352,15 +352,18 @@ async def continuous_rollouts():
async def continuous_training():
training_step = 0
while True:
batch = await replay_buffer.sample.route(curr_policy_version=training_step)
batch = await replay_buffer.sample.call_one(
curr_policy_version=training_step
)
if batch is None:
await asyncio.sleep(0.1)
else:
inputs, targets = batch
loss = await trainer.train_step.route(inputs, targets)
loss = await trainer.train_step.call_one(inputs, targets)
training_step += 1
mlogger.log("loss/training_step", loss, training_step)
await trainer.push_weights.fanout(training_step)

await trainer.push_weights.call(training_step)
await policy.update_weights.fanout(training_step)

print("Starting GRPO training loops...")
Expand All @@ -377,11 +380,11 @@ async def continuous_training():
finally:
print("Shutting down...")
await asyncio.gather(
dataloader.shutdown(),
DatasetActor.shutdown(dataloader),
policy.shutdown(),
trainer.shutdown(),
replay_buffer.shutdown(),
compute_advantages.shutdown(),
RLTrainer.shutdown(trainer),
ReplayBuffer.shutdown(replay_buffer),
ComputeAdvantages.shutdown(compute_advantages),
ref_model.shutdown(),
reward_actor.shutdown(),
)
Expand Down
22 changes: 10 additions & 12 deletions apps/grpo/qwen3_1_7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -101,31 +101,29 @@ ref_model:

# All resource allocations
services:
dataset:
procs: 1
num_replicas: 1
with_gpus: false
policy:
procs: ${policy.engine_config.tensor_parallel_size}
num_replicas: 1
with_gpus: true
trainer:
ref_model:
procs: 1
num_replicas: 1
with_gpus: true
replay_buffer:
reward_actor:
procs: 1
num_replicas: 1
with_gpus: false
ref_model:

actors:
dataset:
procs: 1
with_gpus: false
trainer:
procs: 1
num_replicas: 1
with_gpus: true
compute_advantages:
replay_buffer:
procs: 1
num_replicas: 1
with_gpus: false
reward_actor:
compute_advantages:
procs: 1
num_replicas: 1
with_gpus: false
24 changes: 11 additions & 13 deletions apps/grpo/qwen3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -106,31 +106,29 @@ ref_model:

# All resource allocations
services:
dataset:
procs: 1
num_replicas: 1
with_gpus: false
policy:
procs: ${policy.engine_config.tensor_parallel_size}
num_replicas: 1
with_gpus: true
trainer:
procs: 2
ref_model:
procs: 1
num_replicas: 1
with_gpus: true
replay_buffer:
reward_actor:
procs: 1
num_replicas: 1
with_gpus: false
ref_model:

actors:
dataset:
procs: 1
num_replicas: 1
with_gpus: false
trainer:
procs: 2
with_gpus: true
compute_advantages:
replay_buffer:
procs: 1
num_replicas: 1
with_gpus: false
reward_actor:
compute_advantages:
procs: 1
num_replicas: 1
with_gpus: false
22 changes: 10 additions & 12 deletions apps/grpo/qwen3_multinode.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,33 +46,31 @@ ref_model:
model_name: ${model}

services:
dataset:
procs: 1
num_replicas: 1
with_gpus: false
policy:
procs: 1
hosts: 1
num_replicas: 1
with_gpus: true
trainer:
ref_model:
procs: 1
hosts: 1
num_replicas: 1
with_gpus: true
replay_buffer:
reward_actor:
procs: 1
num_replicas: 1
with_gpus: false

actors:
dataset:
procs: 1
with_gpus: false
compute_advantages:
procs: 1
num_replicas: 1
with_gpus: false
ref_model:
trainer:
procs: 1
num_replicas: 1
hosts: 1
with_gpus: true
reward_actor:
replay_buffer:
procs: 1
num_replicas: 1
with_gpus: false
5 changes: 0 additions & 5 deletions apps/rl/__init__.py

This file was deleted.

62 changes: 0 additions & 62 deletions apps/rl/llama3_8b.yaml

This file was deleted.

Loading
Loading