Commit 92d5a9e

Fabric: reserve externally provided MASTER_PORT values
1 parent 2c497ba commit 92d5a9e

6 files changed: 444 additions & 39 deletions

src/lightning/fabric/plugins/environments/lightning.py

Lines changed: 20 additions & 6 deletions
@@ -104,24 +104,38 @@ def teardown(self) -> None:
         if "WORLD_SIZE" in os.environ:
             del os.environ["WORLD_SIZE"]
 
+        if self._main_port != -1:
+            get_port_manager().release_port(self._main_port)
+            self._main_port = -1
+
+        os.environ.pop("MASTER_PORT", None)
+        os.environ.pop("MASTER_ADDR", None)
+
 
 def find_free_network_port() -> int:
     """Finds a free port on localhost.
 
     It is useful in single-node training when we don't want to connect to a real main node but have to set the
     `MASTER_PORT` environment variable.
 
-    This function uses a global port manager to prevent internal race conditions within the test suite.
     The allocated port is reserved and won't be returned by subsequent calls until it's explicitly released.
 
-    Note:
-        While this prevents collisions between concurrent Lightning tests, external processes can still
-        claim the port between allocation and binding. For production use, explicitly set the MASTER_PORT
-        environment variable.
-
     Returns:
         A port number that is reserved and free at the time of allocation
 
     """
+    # If an external launcher already specified a MASTER_PORT (for example, torch.distributed.spawn or
+    # multiprocessing helpers), reserve it through the port manager so no other test reuses the same number.
+    if "MASTER_PORT" in os.environ:
+        master_port_str = os.environ["MASTER_PORT"]
+        try:
+            existing_port = int(master_port_str)
+        except ValueError:
+            pass
+        else:
+            port_manager = get_port_manager()
+            if port_manager.reserve_existing_port(existing_port):
+                return existing_port
+
     port_manager = get_port_manager()
     return port_manager.allocate_port()

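A minimal usage sketch (not part of this commit) of how the reworked find_free_network_port behaves when a launcher has already exported MASTER_PORT; the module path follows the file above, and the port value 29500 is purely illustrative:

import os

from lightning.fabric.plugins.environments.lightning import find_free_network_port
from lightning.fabric.utilities.port_manager import get_port_manager

# Suppose an external launcher already exported MASTER_PORT (illustrative value).
os.environ["MASTER_PORT"] = "29500"

# The helper now reserves that exact port through the global port manager and
# returns it, so no other allocation in this process can hand out 29500 again.
port = find_free_network_port()
assert port == 29500

# Give the reservation back once the run is over.
get_port_manager().release_port(port)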
src/lightning/fabric/utilities/port_manager.py

Lines changed: 53 additions & 8 deletions
@@ -16,10 +16,16 @@
 import atexit
 import socket
 import threading
+from collections import deque
 from collections.abc import Iterator
 from contextlib import contextmanager
 from typing import Optional
 
+# Size of the recently released ports queue
+# This prevents immediate reuse of ports that were just released
+# Increased to 1024 to reduce the chance of cycling back to TIME_WAIT ports
+_RECENTLY_RELEASED_PORTS_MAXLEN = 1024
+
 
 class PortManager:
     """Thread-safe port manager to prevent EADDRINUSE errors.
@@ -33,10 +39,12 @@ class PortManager:
     def __init__(self) -> None:
         self._lock = threading.Lock()
         self._allocated_ports: set[int] = set()
+        # Recently released ports are kept in a queue to avoid immediate reuse
+        self._recently_released: deque[int] = deque(maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN)
         # Register cleanup to release all ports on exit
         atexit.register(self.release_all)
 
-    def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int = 100) -> int:
+    def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int = 1000) -> int:
         """Allocate a free port, ensuring it's not already reserved.
 
         Args:
@@ -55,23 +63,27 @@ def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int
             if (
                 preferred_port is not None
                 and preferred_port not in self._allocated_ports
+                and preferred_port not in self._recently_released
                 and self._is_port_free(preferred_port)
             ):
                 self._allocated_ports.add(preferred_port)
                 return preferred_port
 
-            # Try to find a free port
+            # Let the OS choose a free port, but verify it's not in our tracking structures
+            # The OS naturally avoids ports in TIME_WAIT (without SO_REUSEADDR)
            for attempt in range(max_attempts):
                 port = self._find_free_port()
 
-                # Double-check it's not in our reserved set (shouldn't happen, but be safe)
-                if port not in self._allocated_ports:
+                # Skip if already allocated by us or recently released
+                # This prevents race conditions within our process
+                if port not in self._allocated_ports and port not in self._recently_released:
                     self._allocated_ports.add(port)
                     return port
 
             raise RuntimeError(
                 f"Failed to allocate a free port after {max_attempts} attempts. "
-                f"Currently allocated ports: {len(self._allocated_ports)}"
+                f"Currently allocated: {len(self._allocated_ports)}, "
+                f"recently released: {len(self._recently_released)}"
             )
 
     def release_port(self, port: int) -> None:
@@ -82,12 +94,43 @@ def release_port(self, port: int) -> None:
 
         """
         with self._lock:
-            self._allocated_ports.discard(port)
+            if port in self._allocated_ports:
+                self._allocated_ports.remove(port)
+                # Add to the back of the queue; oldest will be evicted when queue is full
+                self._recently_released.append(port)
 
     def release_all(self) -> None:
         """Release all allocated ports."""
         with self._lock:
             self._allocated_ports.clear()
+            self._recently_released.clear()
+
+    def reserve_existing_port(self, port: int) -> bool:
+        """Reserve a port that was allocated externally.
+
+        Args:
+            port: The externally assigned port to reserve.
+
+        Returns:
+            True if the port was reserved (or already reserved), False if the port value is invalid.
+
+        """
+        if port <= 0 or port > 65535:
+            return False
+
+        with self._lock:
+            if port in self._allocated_ports:
+                return True
+
+            # Remove from recently released queue if present (we're explicitly reserving it)
+            if port in self._recently_released:
+                # Create a new deque without this port
+                self._recently_released = deque(
+                    (p for p in self._recently_released if p != port), maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN
+                )
+
+            self._allocated_ports.add(port)
+            return True
 
     @contextmanager
     def allocated_port(self, preferred_port: Optional[int] = None) -> Iterator[int]:
@@ -121,7 +164,8 @@ def _find_free_port() -> int:
 
         """
         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        # Don't use SO_REUSEADDR - we need to match the behavior of TCPStore
+        # which binds without it, so ports in TIME_WAIT will be rejected
         s.bind(("", 0))
         port = s.getsockname()[1]
         s.close()
@@ -140,7 +184,8 @@ def _is_port_free(port: int) -> bool:
         """
         try:
             s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            # Don't use SO_REUSEADDR - we need to match the behavior of TCPStore
+            # which binds without it, so ports in TIME_WAIT will be rejected
             s.bind(("", port))
             s.close()
             return True

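A small illustrative sketch (not part of this commit) of the updated PortManager semantics, assuming the class can be instantiated directly; the port number 29500 is made up:

from lightning.fabric.utilities.port_manager import PortManager

manager = PortManager()

# A released port lands in the recently-released queue, so it is not handed out
# again right away, even when requested as the preferred port.
port = manager.allocate_port()
manager.release_port(port)
assert manager.allocate_port(preferred_port=port) != port

# An externally chosen port (for example one exported by a launcher) can be
# reserved so later allocations never collide with it; invalid values are rejected.
assert manager.reserve_existing_port(29500)
assert not manager.reserve_existing_port(-1)
manager.release_port(29500)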
tests/tests_fabric/conftest.py

Lines changed: 59 additions & 9 deletions
@@ -82,20 +82,20 @@ def teardown_process_group():
 
     from lightning.fabric.utilities.port_manager import get_port_manager
 
-    # Record the port used in this test (if any)
-    port_to_release = None
-    if "MASTER_PORT" in os.environ:
-        with contextlib.suppress(ValueError, KeyError):
-            port_to_release = int(os.environ["MASTER_PORT"])
-
     yield
 
     # Clean up distributed connection
     _destroy_dist_connection()
 
-    # Release the port from the manager so it can be reused
-    if port_to_release is not None:
-        get_port_manager().release_port(port_to_release)
+    manager = get_port_manager()
+
+    # If a process group created or updated MASTER_PORT during the test, reserve it and then clear it
+    if "MASTER_PORT" in os.environ:
+        with contextlib.suppress(ValueError):
+            port = int(os.environ["MASTER_PORT"])
+            manager.reserve_existing_port(port)
+            manager.release_port(port)
+        os.environ.pop("MASTER_PORT", None)
 
 
 @pytest.fixture(autouse=True)
@@ -221,6 +221,56 @@ def leave_no_artifacts_behind():
     assert not difference, f"Test left artifacts behind: {difference}"
 
 
+def pytest_runtest_makereport(item: pytest.Item, call: pytest.CallInfo) -> None:
+    """Retry tests that fail with EADDRINUSE errors.
+
+    This handles the race condition where a port might enter TIME_WAIT between our port allocation check and actual
+    binding by TCPStore.
+
+    """
+    if call.excinfo is not None and call.when == "call":
+        exception_msg = str(call.excinfo.value)
+        # Check if this is an EADDRINUSE error from distributed training
+        if "EADDRINUSE" in exception_msg or "address already in use" in exception_msg.lower():
+            # Get the retry count from the test node
+            retry_count = getattr(item, "_port_retry_count", 0)
+            max_retries = 3
+
+            if retry_count < max_retries:
+                # Increment retry counter
+                item._port_retry_count = retry_count + 1
+
+                # Log the retry
+                if hasattr(item.config, "get_terminal_writer"):
+                    writer = item.config.get_terminal_writer()
+                    writer.write(
+                        f"\n[Port conflict detected] Retrying test {item.name} "
+                        f"(attempt {retry_count + 2}/{max_retries + 1})...\n",
+                        yellow=True,
+                    )
+
+                # Clear the port manager's state to get fresh ports
+                from lightning.fabric.utilities.port_manager import get_port_manager
+
+                manager = get_port_manager()
+                manager.release_all()
+
+                # Re-run the test by raising Rerun exception
+                # Note: This requires pytest-rerunfailures plugin
+                import time
+
+                time.sleep(0.5)  # Brief delay to let ports settle
+
+                # If pytest-rerunfailures is available, use it
+                try:
+                    from pytest_rerunfailures import Rerun
+
+                    raise Rerun(f"Port conflict (EADDRINUSE), retry {retry_count + 1}/{max_retries}")
+                except ImportError:
+                    # Plugin not available, just let the test fail
+                    pass
+
+
 def pytest_collection_modifyitems(items: list[pytest.Function], config: pytest.Config) -> None:
     """An adaptation of `tests/tests_pytorch/conftest.py::pytest_collection_modifyitems`"""
     initial_size = len(items)

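Boiled down, the fixture's MASTER_PORT handling is the sequence below; a standalone sketch for illustration only (not part of the commit), with a made-up port value:

import contextlib
import os

from lightning.fabric.utilities.port_manager import get_port_manager

# Pretend a test (or torch.distributed itself) exported MASTER_PORT while running.
os.environ["MASTER_PORT"] = "29501"

manager = get_port_manager()

# Reserve the externally set port so concurrent allocations cannot reuse it, then
# release it (placing it in the recently-released queue) and clear the variable
# so the next test starts from a clean environment.
if "MASTER_PORT" in os.environ:
    with contextlib.suppress(ValueError):
        port = int(os.environ["MASTER_PORT"])
        manager.reserve_existing_port(port)
        manager.release_port(port)
    os.environ.pop("MASTER_PORT", None)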
tests/tests_fabric/plugins/environments/test_lightning.py

Lines changed: 14 additions & 0 deletions
@@ -17,6 +17,7 @@
 import pytest
 
 from lightning.fabric.plugins.environments import LightningEnvironment
+from lightning.fabric.utilities.port_manager import get_port_manager
 
 
 @mock.patch.dict(os.environ, {}, clear=True)
@@ -84,5 +85,18 @@ def test_teardown():
     assert "WORLD_SIZE" not in os.environ
 
 
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_teardown_releases_port_and_env():
+    env = LightningEnvironment()
+    port = env.main_port
+    assert port in get_port_manager()._allocated_ports
+
+    env.teardown()
+
+    assert port not in get_port_manager()._allocated_ports
+    assert "MASTER_PORT" not in os.environ
+    assert "MASTER_ADDR" not in os.environ
+
+
 def test_detect():
     assert LightningEnvironment.detect()
