diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 93807d0e4..a281fb414 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -150,7 +150,6 @@ def __post_init__(self): async def launch( # pyright: ignore[reportIncompatibleMethodOverride] cls: type["Policy"], *, - process_config: ProcessConfig, engine_config: EngineConfig | Mapping = EngineConfig(), sampling_config: SamplingConfig | Mapping = SamplingConfig(), available_devices: str | None = None, @@ -158,6 +157,11 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] ) -> "Policy": # Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES # automatically. + process_config: ProcessConfig = ProcessConfig( + procs=cls.procs, + hosts=cls.hosts, + with_gpus=cls.with_gpus, + ) worker_procs = await get_proc_mesh(process_config=process_config) # TODO - issues/144 we will want to ensure colocation with workers