Skip to content

Commit f15d397

Browse files
committed
working RLtrainer example code, after porting to service API
1 parent b1b3adc commit f15d397

File tree

2 files changed

+9
-28
lines changed

2 files changed

+9
-28
lines changed

apps/rl/llama3_8b.yaml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,6 @@ trainer:
1515
flavor: 8B
1616
tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct
1717

18-
#processes:
19-
# scheduler: local # local | mast (not supported yet)
20-
# num_hosts: 1
21-
# with_gpus: True
22-
# num_procs: 4
2318

2419
optimizer:
2520
name: AdamW
@@ -65,11 +60,6 @@ replay_buffer:
6560
batch_size: 2
6661
max_policy_age: 2
6762
seed: None
68-
processes:
69-
scheduler: local # local | mast (not supported yet)
70-
num_hosts: 1
71-
with_gpus: False
72-
num_procs: 1
7363

7464
# policy:
7565
# scheduler:

apps/rl/main.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@
1515
import sys
1616

1717
from forge.actors import ReplayBuffer, RLTrainer
18-
1918
from forge.cli.config import parse
20-
from forge.controller import spawn_actors
21-
22-
from forge.controller.service import ServiceConfig, spawn_service
19+
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
2320
from omegaconf import DictConfig
2421

2522
logger = logging.getLogger(__name__)
@@ -29,26 +26,20 @@
2926
async def run(cfg: DictConfig):
3027

3128
trainer = await spawn_service(
32-
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
29+
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=4),
3330
RLTrainer,
3431
**cfg.trainer,
3532
)
36-
buffer = (
37-
await spawn_actors(
38-
name="replay_buffer",
39-
actor_cls=ReplayBuffer,
40-
cfg=cfg.replay_buffer,
41-
processes=cfg.replay_buffer.pop("processes"),
42-
),
33+
replay_buffer = await spawn_service(
34+
ServiceConfig(procs_per_replica=1, num_replicas=1),
35+
ReplayBuffer,
36+
**cfg.replay_buffer,
4337
)
44-
print("Actors spawned")
45-
46-
# Initialize everything
47-
await trainer.setup.call()
48-
print("Setup done")
38+
print("Services initialized....")
4939

5040
print("shutting down...")
51-
await asyncio.gather(*[a.mesh.stop() for a in [trainer]])
41+
await shutdown_service(trainer)
42+
await shutdown_service(replay_buffer)
5243

5344

5445
@parse

0 commit comments

Comments
 (0)