Skip to content

Commit 282d6cf

Browse files
committed
borrow code from pr:21239
1 parent 4f274e2 commit 282d6cf

File tree

3 files changed

+255
-0
lines changed

3 files changed

+255
-0
lines changed
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
# Copyright The Lightning AI team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Port allocation manager to prevent race conditions in distributed training."""
15+
16+
import atexit
17+
import socket
18+
import threading
19+
from collections import deque
20+
from collections.abc import Iterator
21+
from contextlib import contextmanager
22+
from typing import Optional
23+
24+
# Maximum number of recently released ports to track before reuse
_RECENTLY_RELEASED_PORTS_MAXLEN = 256


class PortManager:
    """Thread-safe port manager to prevent EADDRINUSE errors.

    This manager maintains a registry of allocated ports to ensure that multiple concurrent tests don't try to
    use the same port. While this doesn't completely eliminate the race condition with external processes, it prevents
    internal collisions within the test suite.

    All bookkeeping (``_allocated_ports`` and ``_recently_released``) is guarded by a single lock.

    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._allocated_ports: set[int] = set()
        # Recently released ports are kept in a bounded queue to avoid immediate reuse;
        # the oldest entry is evicted automatically once the queue is full.
        self._recently_released: deque[int] = deque(maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN)
        # Register cleanup to release all ports on interpreter exit
        atexit.register(self.release_all)

    def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int = 100) -> int:
        """Allocate a free port, ensuring it's not already reserved.

        Args:
            preferred_port: If provided, try to allocate this specific port first. It is only used when it is
                neither allocated nor recently released and can currently be bound; otherwise the manager
                falls back to an OS-assigned port.
            max_attempts: Maximum number of attempts to find a free port

        Returns:
            An allocated port number

        Raises:
            RuntimeError: If unable to find a free port after max_attempts

        """
        with self._lock:
            # If a preferred port is specified and available, use it
            if (
                preferred_port is not None
                and preferred_port not in self._allocated_ports
                and preferred_port not in self._recently_released
                and self._is_port_free(preferred_port)
            ):
                self._allocated_ports.add(preferred_port)
                return preferred_port

            # Ask the OS for candidates until one passes both registry checks
            for _ in range(max_attempts):
                port = self._find_free_port()

                # Skip ports that were recently released to avoid TIME_WAIT conflicts
                if port in self._recently_released:
                    continue

                if port not in self._allocated_ports:
                    self._allocated_ports.add(port)
                    return port

            raise RuntimeError(
                f"Failed to allocate a free port after {max_attempts} attempts. "
                f"Currently allocated ports: {len(self._allocated_ports)}"
            )

    def release_port(self, port: int) -> None:
        """Release a previously allocated port.

        Ports that are not currently allocated are ignored.

        Args:
            port: Port number to release

        """
        with self._lock:
            if port in self._allocated_ports:
                self._allocated_ports.remove(port)
                # Add to the back of the queue; oldest will be evicted when queue is full
                self._recently_released.append(port)

    def release_all(self) -> None:
        """Release all allocated ports and forget the recently released ones."""
        with self._lock:
            self._allocated_ports.clear()
            self._recently_released.clear()

    def reserve_existing_port(self, port: int) -> bool:
        """Reserve a port that was allocated externally.

        Args:
            port: The externally assigned port to reserve.

        Returns:
            True if the port was reserved (or already reserved), False if the port value is invalid.

        """
        if port <= 0 or port > 65535:
            return False

        with self._lock:
            if port in self._allocated_ports:
                return True

            # Remove from recently released queue if present (we're explicitly reserving it)
            if port in self._recently_released:
                self._recently_released = deque(
                    (p for p in self._recently_released if p != port), maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN
                )

            self._allocated_ports.add(port)
            return True

    @contextmanager
    def allocated_port(self, preferred_port: Optional[int] = None) -> Iterator[int]:
        """Context manager for automatic port cleanup.

        Usage:
            with manager.allocated_port() as port:
                # Use port here
                pass
            # Port automatically released

        Args:
            preferred_port: Optional preferred port number

        Yields:
            Allocated port number

        """
        port = self.allocate_port(preferred_port=preferred_port)
        try:
            yield port
        finally:
            self.release_port(port)

    @staticmethod
    def _find_free_port() -> int:
        """Find a free port using OS allocation.

        Returns:
            A port number that was free at the time of checking

        """
        # The ``with`` block guarantees the probe socket is closed even if a call raises.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # Binding to port 0 lets the OS pick an unused port
            s.bind(("", 0))
            return s.getsockname()[1]

    @staticmethod
    def _is_port_free(port: int) -> bool:
        """Check if a specific port is available.

        Args:
            port: Port number to check

        Returns:
            True if the port is free, False otherwise. Values outside 1-65535 are never reported as free:
            port 0 means "let the OS choose" and would always bind successfully, and larger or negative
            values cannot be bound at all.

        """
        if not 0 < port <= 65535:
            return False

        try:
            # The ``with`` block closes the probe socket on the failure path too,
            # so a busy port does not leak a file descriptor.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.bind(("", port))
            return True
        except OSError:
            return False
190+
191+
192+
# Process-wide singleton, created lazily by get_port_manager()
_port_manager: Optional[PortManager] = None
_port_manager_lock = threading.Lock()


def get_port_manager() -> PortManager:
    """Return the shared PortManager, creating it on first use.

    Returns:
        The global PortManager singleton

    """
    global _port_manager

    # Fast path: the instance already exists, no need to take the lock
    if _port_manager is not None:
        return _port_manager

    with _port_manager_lock:
        # Re-check under the lock in case another thread created it meanwhile
        if _port_manager is None:
            _port_manager = PortManager()
        return _port_manager

tests/tests_fabric/conftest.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import contextlib
1415
import os
1516
import sys
1617
import threading
@@ -77,9 +78,31 @@ def restore_env_variables():
7778
@pytest.fixture(autouse=True)
def teardown_process_group():
    """Close the distributed process group after each test and tidy up MASTER_PORT bookkeeping."""
    from lightning.fabric.utilities.port_manager import get_port_manager

    # Remember which port (if any) was configured before the test body runs
    port_to_release = None
    initial_value = os.environ.get("MASTER_PORT")
    if initial_value is not None:
        try:
            port_to_release = int(initial_value)
        except ValueError:
            # Non-numeric MASTER_PORT: nothing to release later
            port_to_release = None

    yield

    # Clean up distributed connection
    _destroy_dist_connection()

    port_manager = get_port_manager()

    # Hand the pre-test port back to the manager so it can be reused
    if port_to_release is not None:
        port_manager.release_port(port_to_release)

    # If MASTER_PORT is still set after the test, reserve it and clear it to avoid leaking between tests
    leftover_value = os.environ.get("MASTER_PORT")
    if leftover_value is not None:
        with contextlib.suppress(ValueError):
            port_manager.reserve_existing_port(int(leftover_value))
        os.environ.pop("MASTER_PORT", None)
105+
83106

84107
@pytest.fixture(autouse=True)
85108
def thread_police_duuu_daaa_duuu_daaa():

tests/tests_pytorch/conftest.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import contextlib
1415
import os
1516
import signal
1617
import sys
@@ -127,9 +128,31 @@ def restore_signal_handlers():
127128
@pytest.fixture(autouse=True)
def teardown_process_group():
    """Close the distributed process group after each test and tidy up MASTER_PORT bookkeeping."""
    from lightning.fabric.utilities.port_manager import get_port_manager

    # Remember which port (if any) was configured before the test body runs
    port_to_release = None
    initial_value = os.environ.get("MASTER_PORT")
    if initial_value is not None:
        try:
            port_to_release = int(initial_value)
        except ValueError:
            # Non-numeric MASTER_PORT: nothing to release later
            port_to_release = None

    yield

    # Clean up distributed connection
    _destroy_dist_connection()

    port_manager = get_port_manager()

    # Hand the pre-test port back to the manager so it can be reused
    if port_to_release is not None:
        port_manager.release_port(port_to_release)

    # If MASTER_PORT is still set after the test, reserve it and clear it to avoid leaking between tests
    leftover_value = os.environ.get("MASTER_PORT")
    if leftover_value is not None:
        with contextlib.suppress(ValueError):
            port_manager.reserve_existing_port(int(leftover_value))
        os.environ.pop("MASTER_PORT", None)
155+
133156

134157
@pytest.fixture(autouse=True)
135158
def reset_deterministic_algorithm():

0 commit comments

Comments
 (0)