Skip to content

Commit a668434

Browse files
committed
update all mains
1 parent 63e7e78 commit a668434

File tree

8 files changed

+684
-679
lines changed

8 files changed

+684
-679
lines changed

.meta/mast/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
MastLauncher,
1616
mount_mnt_directory,
1717
)
18-
from forge.controller.provisioner import init_provisioner
18+
from forge.controller.provisioner import get_or_create_provisioner
1919

2020
from forge.types import (
2121
Launcher,
@@ -68,7 +68,9 @@ async def main(cfg: DictConfig, mode: str = "detached", extra_args: list = None)
6868
else:
6969
# In remote mode, we're already running inside MAST, so mount directory, init provisioner and run training
7070
mount_mnt_directory("/mnt/wsfuse")
71-
await init_provisioner(ProvisionerConfig(launcher_config=launcher_config))
71+
await get_or_create_provisioner(
72+
ProvisionerConfig(launcher_config=launcher_config)
73+
)
7274
await grpo_main(cfg)
7375

7476

apps/grpo/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from forge.actors.replay_buffer import ReplayBuffer
2626
from forge.actors.trainer import RLTrainer
2727
from forge.controller.actor import ForgeActor
28-
from forge.controller.provisioner import init_provisioner, shutdown
28+
from forge.controller.provisioner import get_or_create_provisioner, shutdown
2929
from forge.data.rewards import MathReward, ThinkingReward
3030
from forge.data_models.completion import Completion
3131
from forge.observability.metric_actors import get_or_create_metric_logger
@@ -298,11 +298,11 @@ async def main(cfg: DictConfig):
298298
# ---- Global setups ---- #
299299
provisioner = None
300300
if cfg.get("provisioner", None) is not None:
301-
provisioner = await init_provisioner(
301+
provisioner = await get_or_create_provisioner(
302302
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
303303
)
304304
else:
305-
provisioner = await init_provisioner()
305+
provisioner = await get_or_create_provisioner()
306306

307307
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
308308
mlogger = await get_or_create_metric_logger(process_name="Controller")

0 commit comments

Comments
 (0)