Forward selected environment variables to spawned procs.

Adds a module-level list, _ENV_VARS_TO_INHERIT, and copies each listed
variable that is set in os.environ into env_vars before
host_mesh.spawn_procs is called; env_vars is passed to the bootstrap
callable via functools.partial. Variables that are unset are skipped.

diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index 5ca331f32..329286a81 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -27,6 +27,8 @@
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
+_ENV_VARS_TO_INHERIT = ["TORCHSTORE_RDMA_ENABLED"]
+
 
 def _get_port() -> str:
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -241,6 +243,11 @@ def bootstrap(env: dict[str, str]):
         # Shows detailed logs for Monarch rust failures
         env_vars["RUST_BACKTRACE"] = "1"
 
+        for name in _ENV_VARS_TO_INHERIT:
+            val = os.environ.get(name)
+            if val is not None:
+                env_vars[name] = val
+
         procs = host_mesh.spawn_procs(
             per_host={"gpus": num_procs},
             bootstrap=functools.partial(bootstrap, env=env_vars),