Skip to content

Commit 69c6b1d

Browse files
authored
spawn service based trainer (#131)
1 parent c0c90ac commit 69c6b1d

File tree

2 files changed

+13
-33
lines changed

2 files changed

+13
-33
lines changed

apps/rl/llama3_8b.yaml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,6 @@ trainer:
1515
flavor: 8B
1616
tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct
1717

18-
processes:
19-
scheduler: local # local | mast (not supported yet)
20-
num_hosts: 1
21-
with_gpus: True
22-
num_procs: 4
2318

2419
optimizer:
2520
name: AdamW
@@ -65,11 +60,6 @@ replay_buffer:
6560
batch_size: 2
6661
max_policy_age: 2
6762
seed: None
68-
processes:
69-
scheduler: local # local | mast (not supported yet)
70-
num_hosts: 1
71-
with_gpus: False
72-
num_procs: 1
7363

7464
# policy:
7565
# scheduler:

apps/rl/main.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,42 +15,32 @@
1515
import sys
1616

1717
from forge.actors import ReplayBuffer, RLTrainer
18-
1918
from forge.cli.config import parse
20-
from forge.controller import spawn_actors
19+
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
2120
from omegaconf import DictConfig
2221

2322
logger = logging.getLogger(__name__)
2423
logger.setLevel(logging.INFO)
2524

2625

2726
async def run(cfg: DictConfig):
28-
trainer, buffer = await asyncio.gather(
29-
spawn_actors(
30-
name="trainer",
31-
actor_cls=RLTrainer,
32-
cfg=cfg.trainer,
33-
processes=cfg.trainer.pop("processes"),
34-
set_address=True,
27+
trainer, replay_buffer = await asyncio.gather(
28+
spawn_service(
29+
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=4),
30+
RLTrainer,
31+
**cfg.trainer,
3532
),
36-
spawn_actors(
37-
name="replay_buffer",
38-
actor_cls=ReplayBuffer,
39-
cfg=cfg.replay_buffer,
40-
processes=cfg.replay_buffer.pop("processes"),
33+
spawn_service(
34+
ServiceConfig(procs_per_replica=1, num_replicas=1),
35+
ReplayBuffer,
36+
**cfg.replay_buffer,
4137
),
4238
)
43-
print("Actors spawned")
44-
45-
# Initialize everything
46-
await asyncio.gather(
47-
buffer.setup.call(),
48-
trainer.setup.call(),
49-
)
50-
print("Setup done")
39+
print("Services initialized....")
5140

5241
print("shutting down...")
53-
await asyncio.gather(*[a.mesh.stop() for a in [trainer]])
42+
await shutdown_service(trainer)
43+
await shutdown_service(replay_buffer)
5444

5545

5646
@parse

0 commit comments

Comments
 (0)