From b1b3adc1e250b2d419821e2df0c625bca3df231f Mon Sep 17 00:00:00 2001 From: pradeepfn Date: Fri, 5 Sep 2025 09:44:39 -0700 Subject: [PATCH 1/3] spawn service-based trainer --- apps/rl/llama3_8b.yaml | 10 +++++----- apps/rl/main.py | 24 +++++++++++------------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/apps/rl/llama3_8b.yaml b/apps/rl/llama3_8b.yaml index ac4635575..822d6afef 100644 --- a/apps/rl/llama3_8b.yaml +++ b/apps/rl/llama3_8b.yaml @@ -15,11 +15,11 @@ trainer: flavor: 8B tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct - processes: - scheduler: local # local | mast (not supported yet) - num_hosts: 1 - with_gpus: True - num_procs: 4 + #processes: + # scheduler: local # local | mast (not supported yet) + # num_hosts: 1 + # with_gpus: True + # num_procs: 4 optimizer: name: AdamW diff --git a/apps/rl/main.py b/apps/rl/main.py index 1e5e6f116..c31b63856 100644 --- a/apps/rl/main.py +++ b/apps/rl/main.py @@ -18,6 +18,8 @@ from forge.cli.config import parse from forge.controller import spawn_actors + +from forge.controller.service import ServiceConfig, spawn_service from omegaconf import DictConfig logger = logging.getLogger(__name__) @@ -25,15 +27,14 @@ async def run(cfg: DictConfig): - trainer, buffer = await asyncio.gather( - spawn_actors( - name="trainer", - actor_cls=RLTrainer, - cfg=cfg.trainer, - processes=cfg.trainer.pop("processes"), - set_address=True, - ), - spawn_actors( + + trainer = await spawn_service( + ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1), + RLTrainer, + **cfg.trainer, + ) + buffer = ( + await spawn_actors( name="replay_buffer", actor_cls=ReplayBuffer, cfg=cfg.replay_buffer, @@ -43,10 +44,7 @@ async def run(cfg: DictConfig): print("Actors spawned") # Initialize everything - await asyncio.gather( - buffer.setup.call(), - trainer.setup.call(), - ) + await trainer.setup.call() print("Setup done") print("shutting down...") From f15d3978ec4c16212c250c0758444745484f566b Mon Sep 17 00:00:00 2001 From: 
pradeepfn Date: Mon, 8 Sep 2025 07:41:56 -0700 Subject: [PATCH 2/3] working RLtrainer example code, after porting to service API --- apps/rl/llama3_8b.yaml | 10 ---------- apps/rl/main.py | 27 +++++++++------------------ 2 files changed, 9 insertions(+), 28 deletions(-) diff --git a/apps/rl/llama3_8b.yaml b/apps/rl/llama3_8b.yaml index 822d6afef..45b5eca28 100644 --- a/apps/rl/llama3_8b.yaml +++ b/apps/rl/llama3_8b.yaml @@ -15,11 +15,6 @@ trainer: flavor: 8B tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct - #processes: - # scheduler: local # local | mast (not supported yet) - # num_hosts: 1 - # with_gpus: True - # num_procs: 4 optimizer: name: AdamW @@ -65,11 +60,6 @@ replay_buffer: batch_size: 2 max_policy_age: 2 seed: None - processes: - scheduler: local # local | mast (not supported yet) - num_hosts: 1 - with_gpus: False - num_procs: 1 # policy: # scheduler: diff --git a/apps/rl/main.py b/apps/rl/main.py index c31b63856..2dd55965d 100644 --- a/apps/rl/main.py +++ b/apps/rl/main.py @@ -15,11 +15,8 @@ import sys from forge.actors import ReplayBuffer, RLTrainer - from forge.cli.config import parse -from forge.controller import spawn_actors - -from forge.controller.service import ServiceConfig, spawn_service +from forge.controller.service import ServiceConfig, shutdown_service, spawn_service from omegaconf import DictConfig logger = logging.getLogger(__name__) @@ -29,26 +26,20 @@ async def run(cfg: DictConfig): trainer = await spawn_service( - ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1), + ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=4), RLTrainer, **cfg.trainer, ) - buffer = ( - await spawn_actors( - name="replay_buffer", - actor_cls=ReplayBuffer, - cfg=cfg.replay_buffer, - processes=cfg.replay_buffer.pop("processes"), - ), + replay_buffer = await spawn_service( + ServiceConfig(procs_per_replica=1, num_replicas=1), + ReplayBuffer, + **cfg.replay_buffer, ) - print("Actors spawned") - - # Initialize everything - await 
trainer.setup.call() - print("Setup done") + print("Services initialized....") print("shutting down...") - await asyncio.gather(*[a.mesh.stop() for a in [trainer]]) + await shutdown_service(trainer) + await shutdown_service(replay_buffer) @parse From 4e4b2798229392e70b88b3edd2ac570ed54ae065 Mon Sep 17 00:00:00 2001 From: pradeepfn Date: Mon, 8 Sep 2025 08:56:31 -0700 Subject: [PATCH 3/3] use asyncio.gather --- apps/rl/main.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/apps/rl/main.py b/apps/rl/main.py index 2dd55965d..02255b063 100644 --- a/apps/rl/main.py +++ b/apps/rl/main.py @@ -24,16 +24,17 @@ async def run(cfg: DictConfig): - - trainer = await spawn_service( - ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=4), - RLTrainer, - **cfg.trainer, - ) - replay_buffer = await spawn_service( - ServiceConfig(procs_per_replica=1, num_replicas=1), - ReplayBuffer, - **cfg.replay_buffer, + trainer, replay_buffer = await asyncio.gather( + spawn_service( + ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=4), + RLTrainer, + **cfg.trainer, + ), + spawn_service( + ServiceConfig(procs_per_replica=1, num_replicas=1), + ReplayBuffer, + **cfg.replay_buffer, + ), ) print("Services initialized....")