|  | 
| 28 | 28 | from forge.controller.actor import ForgeActor | 
| 29 | 29 | from forge.controller.provisioner import init_provisioner, shutdown | 
| 30 | 30 | from forge.data.rewards import MathReward, ThinkingReward | 
|  | 31 | +from forge.env import MONARCH_HOSTMESH_V1 | 
| 31 | 32 | from forge.observability.metric_actors import get_or_create_metric_logger | 
| 32 | 33 | from forge.observability.metrics import record_metric, Reduce | 
| 33 | 34 | from forge.observability.perf_tracker import Tracer | 
| @@ -314,14 +315,23 @@ async def main(cfg: DictConfig): | 
| 314 | 315 |     max_res_tokens = cfg.max_res_tokens | 
| 315 | 316 | 
 | 
| 316 | 317 |     # ---- Global setups ---- # | 
|  | 318 | +    provisioner = None | 
| 317 | 319 |     if cfg.get("provisioner", None) is not None: | 
| 318 |  | -        await init_provisioner( | 
|  | 320 | +        provisioner = await init_provisioner( | 
| 319 | 321 |             ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) | 
| 320 | 322 |         ) | 
|  | 323 | +    else: | 
|  | 324 | +        provisioner = await init_provisioner() | 
|  | 325 | + | 
| 321 | 326 |     metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) | 
| 322 | 327 |     mlogger = await get_or_create_metric_logger() | 
| 323 | 328 |     await mlogger.init_backends.call_one(metric_logging_cfg) | 
| 324 |  | -    await ts.initialize(strategy=ts.ControllerStorageVolumes()) | 
|  | 329 | + | 
|  | 330 | +    # In the host mesh v0 case, actors on remote hosts cannot communicate | 
|  | 331 | +    # with one another, so we use the controller as our storage volume. | 
|  | 332 | +    if provisioner is None or not MONARCH_HOSTMESH_V1.get_value(): | 
|  | 333 | +        await ts.initialize(strategy=ts.ControllerStorageVolumes()) | 
|  | 334 | +    print("Torchstore successfully initialized with controller storage strategy") | 
| 325 | 335 | 
 | 
| 326 | 336 |     # ---- Setup services ---- # | 
| 327 | 337 | 
 | 
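Read in isolation, the storage logic in this first hunk is easy to miss among the context lines. Below is a minimal sketch of the v0 fallback path, assuming `ts` is the `torchstore` package imported near the top of this file (that import sits outside the hunk); every other name comes straight from the diff, so treat this as an illustration rather than the verbatim file.

```python
# Sketch of the host mesh v0 fallback above. Assumes `ts` is torchstore,
# as used elsewhere in this file.
import torchstore as ts
from forge.env import MONARCH_HOSTMESH_V1

async def init_store_v0_fallback() -> None:
    # v0 host meshes cannot route traffic between remote hosts, so the
    # controller process must own the storage volumes.
    if not MONARCH_HOSTMESH_V1.get_value():
        await ts.initialize(strategy=ts.ControllerStorageVolumes())
        print("Torchstore successfully initialized with controller storage strategy")
```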
| @@ -351,6 +361,16 @@ async def main(cfg: DictConfig): | 
| 351 | 361 | 
 | 
| 352 | 362 |     print("All services initialized successfully!") | 
| 353 | 363 | 
 | 
|  | 364 | +    if provisioner is not None and MONARCH_HOSTMESH_V1.get_value(): | 
|  | 365 | +        # TODO: support multiple host meshes | 
|  | 366 | +        trainer_host_mesh_name = cfg.actors.trainer["mesh_name"] | 
|  | 367 | +        trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name) | 
|  | 368 | +        await ts.initialize( | 
|  | 369 | +            mesh=trainer_hosts.spawn_procs(per_host={"gpus": 8}), | 
|  | 370 | +            strategy=ts.LocalRankStrategy(), | 
|  | 371 | +        ) | 
|  | 372 | +    print("Torchstore successfully initialized with local rank strategy") | 
|  | 373 | + | 
| 354 | 374 |     # ---- Core RL loops ---- # | 
| 355 | 375 |     async def continuous_rollouts(): | 
| 356 | 376 |         rollout_count = 0 | 
|  | 
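The second hunk's v1 path, pieced together as a self-contained sketch: the provisioner returned by `init_provisioner` hands back the trainer's host mesh, torchstore procs are spawned on those hosts (the `per_host={"gpus": 8}` value is hard-coded in the diff), and storage is sharded one volume per local rank. API names are taken from the hunk above, and `ts` is again assumed to be `torchstore`; this is an illustrative outline, not the committed code.

```python
# Sketch of the host mesh v1 path above, colocating torchstore shards
# with the trainer hosts. Assumes `ts` is torchstore.
import torchstore as ts
from forge.env import MONARCH_HOSTMESH_V1

async def init_store_v1(provisioner, cfg) -> None:
    if provisioner is not None and MONARCH_HOSTMESH_V1.get_value():
        # Only a single host mesh is supported for now (see the TODO above).
        trainer_hosts = provisioner.get_host_mesh(cfg.actors.trainer["mesh_name"])
        await ts.initialize(
            # One torchstore proc per GPU on each trainer host.
            mesh=trainer_hosts.spawn_procs(per_host={"gpus": 8}),
            # One storage shard per local rank, colocated with the trainer.
            strategy=ts.LocalRankStrategy(),
        )
        print("Torchstore successfully initialized with local rank strategy")
```

Presumably the point of the split is the one stated in the first hunk's comment: under v0, remote hosts cannot reach each other, so all storage must live on the controller, whereas v1 meshes allow direct host-to-host traffic and the store can be sharded next to the trainer ranks.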