diff --git a/apps/rl/llama3_8b.yaml b/apps/rl/llama3_8b.yaml
index ac4635575..45b5eca28 100644
--- a/apps/rl/llama3_8b.yaml
+++ b/apps/rl/llama3_8b.yaml
@@ -15,11 +15,6 @@ trainer:
   flavor: 8B
   tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct
 
-  processes:
-    scheduler: local # local | mast (not supported yet)
-    num_hosts: 1
-    with_gpus: True
-    num_procs: 4
 
 optimizer:
   name: AdamW
@@ -65,11 +60,6 @@ replay_buffer:
   batch_size: 2
   max_policy_age: 2
   seed: None
-  processes:
-    scheduler: local # local | mast (not supported yet)
-    num_hosts: 1
-    with_gpus: False
-    num_procs: 1
 
 # policy:
 #   scheduler:
diff --git a/apps/rl/main.py b/apps/rl/main.py
index 1e5e6f116..02255b063 100644
--- a/apps/rl/main.py
+++ b/apps/rl/main.py
@@ -15,9 +15,8 @@
 import sys
 
 from forge.actors import ReplayBuffer, RLTrainer
-
 from forge.cli.config import parse
-from forge.controller import spawn_actors
+from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from omegaconf import DictConfig
 
 logger = logging.getLogger(__name__)
@@ -25,32 +24,23 @@
 
 
 async def run(cfg: DictConfig):
-    trainer, buffer = await asyncio.gather(
-        spawn_actors(
-            name="trainer",
-            actor_cls=RLTrainer,
-            cfg=cfg.trainer,
-            processes=cfg.trainer.pop("processes"),
-            set_address=True,
+    trainer, replay_buffer = await asyncio.gather(
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=4),
+            RLTrainer,
+            **cfg.trainer,
         ),
-        spawn_actors(
-            name="replay_buffer",
-            actor_cls=ReplayBuffer,
-            cfg=cfg.replay_buffer,
-            processes=cfg.replay_buffer.pop("processes"),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ReplayBuffer,
+            **cfg.replay_buffer,
         ),
     )
-    print("Actors spawned")
-
-    # Initialize everything
-    await asyncio.gather(
-        buffer.setup.call(),
-        trainer.setup.call(),
-    )
-    print("Setup done")
+    print("Services initialized....")
 
     print("shutting down...")
-    await asyncio.gather(*[a.mesh.stop() for a in [trainer]])
+    await shutdown_service(trainer)
+    await shutdown_service(replay_buffer)
 
 
 @parse