Skip to content

Commit 9a011a6

Browse files
committed
a few more clean ups
1 parent e6f4fb6 commit 9a011a6

File tree

3 files changed

+16
-24
lines changed

3 files changed

+16
-24
lines changed

apps/grpo/main.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,13 @@
2626
from forge.actors.trainer import RLTrainer
2727
from forge.cli.config import parse
2828
from forge.controller.actor import ForgeActor
29-
from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
3029
from forge.controller.provisioner import init_provisioner, shutdown
3130
from forge.data.rewards import MathReward, ThinkingReward
3231
from forge.observability.metric_actors import get_or_create_metric_logger
3332
from forge.observability.metrics import record_metric, Reduce
3433
from forge.observability.perf_tracker import Tracer
3534

36-
from forge.types import (
37-
Launcher,
38-
LauncherConfig,
39-
ProcessConfig,
40-
ProvisionerConfig,
41-
ServiceConfig,
42-
)
35+
from forge.types import LauncherConfig, ProvisionerConfig
4336
from forge.util.ops import compute_logprobs
4437
from monarch.actor import endpoint
4538
from omegaconf import DictConfig
@@ -320,25 +313,21 @@ async def main(cfg: DictConfig):
320313
max_req_tokens = cfg.max_req_tokens
321314
max_res_tokens = cfg.max_res_tokens
322315

323-
# init provisioner
324-
await init_provisioner(
325-
ProvisionerConfig(
326-
launcher_config=LauncherConfig(
327-
launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.SLURM.value)),
328-
job_name=cfg.get(JOB_NAME_KEY, None),
329-
services={k: ServiceConfig(**v) for k, v in cfg.services.items()},
330-
actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()},
316+
# ---- Global setups ---- #
317+
if cfg.get("provisioner", None) is not None:
318+
await init_provisioner(
319+
ProvisionerConfig(
320+
launcher_config=LauncherConfig(**cfg.provisioner.launcher)
331321
)
332322
)
333-
)
334-
335323
# initialize before spawning services
336324
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
337325
mlogger = await get_or_create_metric_logger()
338326
await mlogger.init_backends.call_one(metric_logging_cfg)
327+
await ts.initialize(strategy=ts.ControllerStorageVolumes())
339328

340329
# ---- Setup services ---- #
341-
await ts.initialize(strategy=ts.ControllerStorageVolumes())
330+
342331
(
343332
dataloader,
344333
policy,

apps/grpo/qwen3_32b.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ max_res_tokens: 512
1010
model: "Qwen/Qwen3-32B"
1111
off_by_n: 1 # Off by one by default
1212

13+
provisioner:
14+
launcher: slurm
15+
1316
# Main loop configuration
1417
rollout_threads: 1 # Recommended to set equal to policy.num_replicas
1518

src/forge/controller/provisioner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,20 @@ def _get_port() -> str:
3636
return str(port)
3737

3838

39-
class _SetupActor(Actor):
39+
class _RemoteInfoFetcher(Actor):
4040
"""An actor responsible for getting remote host information."""
4141

4242
@endpoint
43-
def get_info(self) -> [str, str]:
43+
def get_info(self) -> tuple[str, str]:
4444
return socket.gethostname(), _get_port()
4545

4646

4747
async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
4848
"""Returns the host name and port of the host mesh."""
4949
throwaway_procs = host_mesh.spawn_procs(per_host={"procs": 1})
50-
setup_actor = throwaway_procs.spawn("_setup_actor", _SetupActor)
51-
setup_actor = setup_actor.slice(procs=0)
52-
host, port = await setup_actor.get_info.call_one()
50+
fetcher = throwaway_procs.spawn("_fetcher", _RemoteInfoFetcher)
51+
fetcher = fetcher.slice(procs=0)
52+
host, port = await fetcher.get_info.call_one()
5353
await throwaway_procs.stop()
5454
return host, port
5555

0 commit comments

Comments
 (0)