Skip to content

Commit b1b3adc

Browse files
committed
spawn servic based trainer
1 parent 4372a54 commit b1b3adc

File tree

2 files changed

+16
-18
lines changed

2 files changed

+16
-18
lines changed

apps/rl/llama3_8b.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ trainer:
1515
flavor: 8B
1616
tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct
1717

18-
processes:
19-
scheduler: local # local | mast (not supported yet)
20-
num_hosts: 1
21-
with_gpus: True
22-
num_procs: 4
18+
#processes:
19+
# scheduler: local # local | mast (not supported yet)
20+
# num_hosts: 1
21+
# with_gpus: True
22+
# num_procs: 4
2323

2424
optimizer:
2525
name: AdamW

apps/rl/main.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,23 @@
1818

1919
from forge.cli.config import parse
2020
from forge.controller import spawn_actors
21+
22+
from forge.controller.service import ServiceConfig, spawn_service
2123
from omegaconf import DictConfig
2224

2325
logger = logging.getLogger(__name__)
2426
logger.setLevel(logging.INFO)
2527

2628

2729
async def run(cfg: DictConfig):
28-
trainer, buffer = await asyncio.gather(
29-
spawn_actors(
30-
name="trainer",
31-
actor_cls=RLTrainer,
32-
cfg=cfg.trainer,
33-
processes=cfg.trainer.pop("processes"),
34-
set_address=True,
35-
),
36-
spawn_actors(
30+
31+
trainer = await spawn_service(
32+
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
33+
RLTrainer,
34+
**cfg.trainer,
35+
)
36+
buffer = (
37+
await spawn_actors(
3738
name="replay_buffer",
3839
actor_cls=ReplayBuffer,
3940
cfg=cfg.replay_buffer,
@@ -43,10 +44,7 @@ async def run(cfg: DictConfig):
4344
print("Actors spawned")
4445

4546
# Initialize everything
46-
await asyncio.gather(
47-
buffer.setup.call(),
48-
trainer.setup.call(),
49-
)
47+
await trainer.setup.call()
5048
print("Setup done")
5149

5250
print("shutting down...")

0 commit comments

Comments
 (0)