Skip to content

Commit 2c497ba

Browse files
committed
Fix: Eliminate port collision race conditions in distributed tests
Implemented a thread-safe port reservation system to prevent EADDRINUSE errors in distributed training tests. - Created PortManager class with mutex-protected port allocation - Updated find_free_network_port() to use PortManager - Enhanced test teardown to release ports after completion - Added 24 comprehensive tests (17 unit + 7 integration) - Added context manager for automatic port cleanup Fixes port collision issues in: - tests_fabric/strategies/test_ddp_integration.py::test_clip_gradients - tests_pytorch/strategies/test_fsdp.py::test_checkpoint_multi_gpus
1 parent f58a176 commit 2c497ba

File tree

5 files changed

+739
-6
lines changed

5 files changed

+739
-6
lines changed

src/lightning/fabric/plugins/environments/lightning.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
# limitations under the License.
1414

1515
import os
16-
import socket
1716

1817
from typing_extensions import override
1918

2019
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
20+
from lightning.fabric.utilities.port_manager import get_port_manager
2121
from lightning.fabric.utilities.rank_zero import rank_zero_only
2222

2323

@@ -111,9 +111,17 @@ def find_free_network_port() -> int:
111111
It is useful in single-node training when we don't want to connect to a real main node but have to set the
112112
`MASTER_PORT` environment variable.
113113
114+
This function uses a global port manager to prevent internal race conditions within the test suite.
115+
The allocated port is reserved and won't be returned by subsequent calls until it's explicitly released.
116+
117+
Note:
118+
While this prevents collisions between concurrent Lightning tests, external processes can still
119+
claim the port between allocation and binding. For production use, explicitly set the MASTER_PORT
120+
environment variable.
121+
122+
Returns:
123+
A port number that is reserved and free at the time of allocation
124+
114125
"""
115-
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
116-
s.bind(("", 0))
117-
port = s.getsockname()[1]
118-
s.close()
119-
return port
126+
port_manager = get_port_manager()
127+
return port_manager.allocate_port()
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# Copyright The Lightning AI team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Port allocation manager to prevent race conditions in distributed training."""
15+
16+
import atexit
17+
import socket
18+
import threading
19+
from collections.abc import Iterator
20+
from contextlib import contextmanager
21+
from typing import Optional
22+
23+
24+
class PortManager:
25+
"""Thread-safe port manager to prevent EADDRINUSE errors.
26+
27+
This manager maintains a global registry of allocated ports to ensure that multiple concurrent tests don't try to
28+
use the same port. While this doesn't completely eliminate the race condition with external processes, it prevents
29+
internal collisions within the test suite.
30+
31+
"""
32+
33+
def __init__(self) -> None:
34+
self._lock = threading.Lock()
35+
self._allocated_ports: set[int] = set()
36+
# Register cleanup to release all ports on exit
37+
atexit.register(self.release_all)
38+
39+
def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int = 100) -> int:
40+
"""Allocate a free port, ensuring it's not already reserved.
41+
42+
Args:
43+
preferred_port: If provided, try to allocate this specific port first
44+
max_attempts: Maximum number of attempts to find a free port
45+
46+
Returns:
47+
An allocated port number
48+
49+
Raises:
50+
RuntimeError: If unable to find a free port after max_attempts
51+
52+
"""
53+
with self._lock:
54+
# If a preferred port is specified and available, use it
55+
if (
56+
preferred_port is not None
57+
and preferred_port not in self._allocated_ports
58+
and self._is_port_free(preferred_port)
59+
):
60+
self._allocated_ports.add(preferred_port)
61+
return preferred_port
62+
63+
# Try to find a free port
64+
for attempt in range(max_attempts):
65+
port = self._find_free_port()
66+
67+
# Double-check it's not in our reserved set (shouldn't happen, but be safe)
68+
if port not in self._allocated_ports:
69+
self._allocated_ports.add(port)
70+
return port
71+
72+
raise RuntimeError(
73+
f"Failed to allocate a free port after {max_attempts} attempts. "
74+
f"Currently allocated ports: {len(self._allocated_ports)}"
75+
)
76+
77+
def release_port(self, port: int) -> None:
78+
"""Release a previously allocated port.
79+
80+
Args:
81+
port: Port number to release
82+
83+
"""
84+
with self._lock:
85+
self._allocated_ports.discard(port)
86+
87+
def release_all(self) -> None:
88+
"""Release all allocated ports."""
89+
with self._lock:
90+
self._allocated_ports.clear()
91+
92+
@contextmanager
93+
def allocated_port(self, preferred_port: Optional[int] = None) -> Iterator[int]:
94+
"""Context manager for automatic port cleanup.
95+
96+
Usage:
97+
with manager.allocated_port() as port:
98+
# Use port here
99+
pass
100+
# Port automatically released
101+
102+
Args:
103+
preferred_port: Optional preferred port number
104+
105+
Yields:
106+
Allocated port number
107+
108+
"""
109+
port = self.allocate_port(preferred_port=preferred_port)
110+
try:
111+
yield port
112+
finally:
113+
self.release_port(port)
114+
115+
@staticmethod
116+
def _find_free_port() -> int:
117+
"""Find a free port using OS allocation.
118+
119+
Returns:
120+
A port number that was free at the time of checking
121+
122+
"""
123+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
124+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
125+
s.bind(("", 0))
126+
port = s.getsockname()[1]
127+
s.close()
128+
return port
129+
130+
@staticmethod
131+
def _is_port_free(port: int) -> bool:
132+
"""Check if a specific port is available.
133+
134+
Args:
135+
port: Port number to check
136+
137+
Returns:
138+
True if the port is free, False otherwise
139+
140+
"""
141+
try:
142+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
143+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
144+
s.bind(("", port))
145+
s.close()
146+
return True
147+
except OSError:
148+
return False
149+
150+
151+
# Global singleton instance
152+
_port_manager: Optional[PortManager] = None
153+
_port_manager_lock = threading.Lock()
154+
155+
156+
def get_port_manager() -> PortManager:
157+
"""Get or create the global port manager instance.
158+
159+
Returns:
160+
The global PortManager singleton
161+
162+
"""
163+
global _port_manager
164+
if _port_manager is None:
165+
with _port_manager_lock:
166+
if _port_manager is None:
167+
_port_manager = PortManager()
168+
return _port_manager

tests/tests_fabric/conftest.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import contextlib
1415
import os
1516
import sys
1617
import threading
@@ -77,9 +78,25 @@ def restore_env_variables():
7778
@pytest.fixture(autouse=True)
7879
def teardown_process_group():
7980
"""Ensures that the distributed process group gets closed before the next test runs."""
81+
import os
82+
83+
from lightning.fabric.utilities.port_manager import get_port_manager
84+
85+
# Record the port used in this test (if any)
86+
port_to_release = None
87+
if "MASTER_PORT" in os.environ:
88+
with contextlib.suppress(ValueError, KeyError):
89+
port_to_release = int(os.environ["MASTER_PORT"])
90+
8091
yield
92+
93+
# Clean up distributed connection
8194
_destroy_dist_connection()
8295

96+
# Release the port from the manager so it can be reused
97+
if port_to_release is not None:
98+
get_port_manager().release_port(port_to_release)
99+
83100

84101
@pytest.fixture(autouse=True)
85102
def thread_police_duuu_daaa_duuu_daaa():

0 commit comments

Comments
 (0)