|
26 | 26 | from forge.actors.trainer import RLTrainer |
27 | 27 | from forge.cli.config import parse |
28 | 28 | from forge.controller.actor import ForgeActor |
29 | | -from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY |
30 | 29 | from forge.controller.provisioner import init_provisioner, shutdown |
31 | 30 | from forge.data.rewards import MathReward, ThinkingReward |
32 | 31 | from forge.observability.metric_actors import get_or_create_metric_logger |
33 | 32 | from forge.observability.metrics import record_metric, Reduce |
34 | 33 | from forge.observability.perf_tracker import Tracer |
35 | 34 |
|
36 | | -from forge.types import ( |
37 | | - Launcher, |
38 | | - LauncherConfig, |
39 | | - ProcessConfig, |
40 | | - ProvisionerConfig, |
41 | | - ServiceConfig, |
42 | | -) |
| 35 | +from forge.types import LauncherConfig, ProvisionerConfig |
43 | 36 | from forge.util.ops import compute_logprobs |
44 | 37 | from monarch.actor import endpoint |
45 | 38 | from omegaconf import DictConfig |
@@ -320,25 +313,21 @@ async def main(cfg: DictConfig): |
320 | 313 | max_req_tokens = cfg.max_req_tokens |
321 | 314 | max_res_tokens = cfg.max_res_tokens |
322 | 315 |
|
323 | | - # init provisioner |
324 | | - await init_provisioner( |
325 | | - ProvisionerConfig( |
326 | | - launcher_config=LauncherConfig( |
327 | | - launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.SLURM.value)), |
328 | | - job_name=cfg.get(JOB_NAME_KEY, None), |
329 | | - services={k: ServiceConfig(**v) for k, v in cfg.services.items()}, |
330 | | - actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()}, |
| 316 | + # ---- Global setups ---- # |
| 317 | + if cfg.get("provisioner", None) is not None: |
| 318 | + await init_provisioner( |
| 319 | + ProvisionerConfig( |
| 320 | + launcher_config=LauncherConfig(**cfg.provisioner.launcher) |
331 | 321 | ) |
332 | 322 | ) |
333 | | - ) |
334 | | - |
335 | 323 | # initialize before spawning services |
336 | 324 | metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) |
337 | 325 | mlogger = await get_or_create_metric_logger() |
338 | 326 | await mlogger.init_backends.call_one(metric_logging_cfg) |
| 327 | + await ts.initialize(strategy=ts.ControllerStorageVolumes()) |
339 | 328 |
|
340 | 329 | # ---- Setup services ---- # |
341 | | - await ts.initialize(strategy=ts.ControllerStorageVolumes()) |
| 330 | + |
342 | 331 | ( |
343 | 332 | dataloader, |
344 | 333 | policy, |
|
0 commit comments