Skip to content

Commit 5935260

Browse files
committed
[WIP] Skeleton of GRPO
1 parent 4d6dfc7 commit 5935260

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed

apps/grpo/main.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import asyncio
2+
3+
from forge.actors.policy import Policy, PolicyRouter
4+
from forge.controller import ServiceConfig, spawn_service
5+
from forge.data.replay_buffer import ReplayBuffer
6+
from monarch.actor import Actor, endpoint
7+
8+
9+
class Trainer(Actor):
10+
def __init__(self):
11+
pass
12+
13+
@endpoint
14+
async def train_step(self, batch):
15+
pass
16+
17+
@endpoint
18+
async def update_weights(self):
19+
pass
20+
21+
22+
async def generate_rollout():
23+
pass
24+
25+
26+
async def main():
27+
# ---- Setup services ---- #
28+
default_service_cfg = ServiceConfig(
29+
procs_per_replica=1,
30+
min_replicas=1,
31+
max_replicas=1,
32+
default_replicas=1,
33+
)
34+
policy = await spawn_service(
35+
default_service_cfg,
36+
PolicyRouter,
37+
policy=Policy(model="Deepseek/Deepseek-v3"),
38+
)
39+
trainer = await spawn_service(
40+
default_service_cfg,
41+
Trainer,
42+
)
43+
replay_buffer = await spawn_service(
44+
default_service_cfg,
45+
ReplayBuffer,
46+
batch_size=4,
47+
max_policy_age=1,
48+
)
49+
50+
async def continuous_rollouts():
51+
while True:
52+
current_version = await policy.get_current_version()
53+
episode = await generate_rollout()
54+
await replay_buffer.add.call(episode)
55+
56+
rollout_task = asyncio.create_task(continuous_rollouts())
57+
58+
async def continuous_training():
59+
while True:
60+
batch = await replay_buffer.sample.call()
61+
if batch is not None:
62+
await trainer.train_step.call(batch)
63+
await trainer.update_policy.call()
64+
65+
training_task = asyncio.create_task(continuous_training())
66+
67+
await asyncio.gather(rollout_task, training_task)
68+
69+
70+
if __name__ == "__main__":
71+
asyncio.run(main())

0 commit comments

Comments
 (0)