|
28 | 28 | from forge.controller.actor import ForgeActor |
29 | 29 | from forge.controller.provisioner import init_provisioner, shutdown |
30 | 30 | from forge.data.rewards import MathReward, ThinkingReward |
| 31 | +from forge.env import MONARCH_HOSTMESH_V1 |
31 | 32 | from forge.observability.metric_actors import get_or_create_metric_logger |
32 | 33 | from forge.observability.metrics import record_metric, Reduce |
33 | 34 | from forge.observability.perf_tracker import Tracer |
@@ -314,14 +315,23 @@ async def main(cfg: DictConfig): |
314 | 315 | max_res_tokens = cfg.max_res_tokens |
315 | 316 |
|
316 | 317 | # ---- Global setups ---- # |
| 318 | + provisioner = None |
317 | 319 | if cfg.get("provisioner", None) is not None: |
318 | | - await init_provisioner( |
| 320 | + provisioner = await init_provisioner( |
319 | 321 | ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) |
320 | 322 | ) |
| 323 | + else: |
| 324 | + provisioner = await init_provisioner() |
| 325 | + |
321 | 326 | metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) |
322 | 327 | mlogger = await get_or_create_metric_logger() |
323 | 328 | await mlogger.init_backends.call_one(metric_logging_cfg) |
324 | | - await ts.initialize(strategy=ts.ControllerStorageVolumes()) |
| 329 | + |
| 330 | +    # In the HostMesh v0 case, actors on remote hosts are not able to communicate |
| 331 | +    # with one another, so we use the controller as our storage volume. |
| 332 | + if not MONARCH_HOSTMESH_V1.get_value(): |
| 333 | + await ts.initialize(strategy=ts.ControllerStorageVolumes()) |
| 334 | + print("Torchstore successfully initialized with controller storage strategy") |
325 | 335 |
|
326 | 336 | # ---- Setup services ---- # |
327 | 337 |
|
@@ -351,6 +361,22 @@ async def main(cfg: DictConfig): |
351 | 361 |
|
352 | 362 | print("All services initialized successfully!") |
353 | 363 |
|
| 364 | + # In the HostMesh v1 case, we spawn a torchstore storage volume |
| 365 | + # per trainer process. |
| 366 | + # We initialize after service initialization because torchstore currently |
| 367 | + # requires access to the underlying proc meshes in the local rank strategy. |
| 368 | + # We should be able to hide this in the future. |
| 369 | + if MONARCH_HOSTMESH_V1.get_value(): |
| 370 | + # TODO: support multiple host meshes |
| 371 | + trainer_num_procs = cfg.actors.trainer["procs"] |
| 372 | + trainer_host_mesh_name = cfg.actors.trainer["mesh_name"] |
| 373 | + trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name) |
| 374 | + await ts.initialize( |
| 375 | + mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}), |
| 376 | + strategy=ts.LocalRankStrategy(), |
| 377 | + ) |
| 378 | + print("Torchstore successfully initialized with local rank strategy") |
| 379 | + |
354 | 380 | # ---- Core RL loops ---- # |
355 | 381 | async def continuous_rollouts(): |
356 | 382 | rollout_count = 0 |
|
0 commit comments