|
15 | 15 | import sys |
16 | 16 |
|
17 | 17 | from forge.actors import ReplayBuffer, RLTrainer |
18 | | - |
19 | 18 | from forge.cli.config import parse |
20 | | -from forge.controller import spawn_actors |
21 | | - |
22 | | -from forge.controller.service import ServiceConfig, spawn_service |
| 19 | +from forge.controller.service import ServiceConfig, shutdown_service, spawn_service |
23 | 20 | from omegaconf import DictConfig |
24 | 21 |
|
25 | 22 | logger = logging.getLogger(__name__) |
|
29 | 26 | async def run(cfg: DictConfig): |
30 | 27 |
|
31 | 28 | trainer = await spawn_service( |
32 | | - ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1), |
| 29 | + ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=4), |
33 | 30 | RLTrainer, |
34 | 31 | **cfg.trainer, |
35 | 32 | ) |
36 | | - buffer = ( |
37 | | - await spawn_actors( |
38 | | - name="replay_buffer", |
39 | | - actor_cls=ReplayBuffer, |
40 | | - cfg=cfg.replay_buffer, |
41 | | - processes=cfg.replay_buffer.pop("processes"), |
42 | | - ), |
| 33 | + replay_buffer = await spawn_service( |
| 34 | + ServiceConfig(procs_per_replica=1, num_replicas=1), |
| 35 | + ReplayBuffer, |
| 36 | + **cfg.replay_buffer, |
43 | 37 | ) |
44 | | - print("Actors spawned") |
45 | | - |
46 | | - # Initialize everything |
47 | | - await trainer.setup.call() |
48 | | - print("Setup done") |
| 38 | + print("Services initialized....") |
49 | 39 |
|
50 | 40 | print("shutting down...") |
51 | | - await asyncio.gather(*[a.mesh.stop() for a in [trainer]]) |
| 41 | + await shutdown_service(trainer) |
| 42 | + await shutdown_service(replay_buffer) |
52 | 43 |
|
53 | 44 |
|
54 | 45 | @parse |
|
0 commit comments