| 
 | 1 | +# Copyright The Lightning AI team.  | 
 | 2 | +#  | 
 | 3 | +# Licensed under the Apache License, Version 2.0 (the "License");  | 
 | 4 | +# you may not use this file except in compliance with the License.  | 
 | 5 | +# You may obtain a copy of the License at  | 
 | 6 | +#  | 
 | 7 | +#     http://www.apache.org/licenses/LICENSE-2.0  | 
 | 8 | +#  | 
 | 9 | +# Unless required by applicable law or agreed to in writing, software  | 
 | 10 | +# distributed under the License is distributed on an "AS IS" BASIS,  | 
 | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
 | 12 | +# See the License for the specific language governing permissions and  | 
 | 13 | +# limitations under the License.  | 
 | 14 | +"""Port allocation manager to prevent race conditions in distributed training."""  | 
 | 15 | + | 
 | 16 | +import atexit  | 
 | 17 | +import logging  | 
 | 18 | +import socket  | 
 | 19 | +import threading  | 
 | 20 | +from collections import deque  | 
 | 21 | +from collections.abc import Iterator  | 
 | 22 | +from contextlib import contextmanager  | 
 | 23 | +from typing import Optional  | 
 | 24 | + | 
 | 25 | +log = logging.getLogger(__name__)  | 
 | 26 | + | 
 | 27 | +# Size of the recently released ports queue  | 
 | 28 | +# This prevents immediate reuse of ports that were just released  | 
 | 29 | +# Set to 1024 to balance memory usage vs TIME_WAIT protection  | 
 | 30 | +_RECENTLY_RELEASED_PORTS_MAXLEN = 1024  | 
 | 31 | + | 
 | 32 | + | 
 | 33 | +class PortManager:  | 
 | 34 | +    """Thread-safe port manager to prevent EADDRINUSE errors.  | 
 | 35 | +
  | 
 | 36 | +    This manager maintains a global registry of allocated ports to ensure that multiple concurrent tests don't try to  | 
 | 37 | +    use the same port. While this doesn't completely eliminate the race condition with external processes, it prevents  | 
 | 38 | +    internal collisions within the test suite.  | 
 | 39 | +
  | 
 | 40 | +    """  | 
 | 41 | + | 
 | 42 | +    def __init__(self) -> None:  | 
 | 43 | +        self._lock = threading.Lock()  | 
 | 44 | +        self._allocated_ports: set[int] = set()  | 
 | 45 | +        # Recently released ports are kept in a queue to avoid immediate reuse  | 
 | 46 | +        self._recently_released: deque[int] = deque(maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN)  | 
 | 47 | +        # Register cleanup to release all ports on exit  | 
 | 48 | +        atexit.register(self.release_all)  | 
 | 49 | + | 
 | 50 | +    def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int = 1000) -> int:  | 
 | 51 | +        """Allocate a free port, ensuring it's not already reserved.  | 
 | 52 | +
  | 
 | 53 | +        Args:  | 
 | 54 | +            preferred_port: If provided, try to allocate this specific port first  | 
 | 55 | +            max_attempts: Maximum number of attempts to find a free port  | 
 | 56 | +
  | 
 | 57 | +        Returns:  | 
 | 58 | +            An allocated port number  | 
 | 59 | +
  | 
 | 60 | +        Raises:  | 
 | 61 | +            RuntimeError: If unable to find a free port after max_attempts  | 
 | 62 | +
  | 
 | 63 | +        """  | 
 | 64 | +        with self._lock:  | 
 | 65 | +            # If a preferred port is specified and available, use it  | 
 | 66 | +            if (  | 
 | 67 | +                preferred_port is not None  | 
 | 68 | +                and preferred_port not in self._allocated_ports  | 
 | 69 | +                and preferred_port not in self._recently_released  | 
 | 70 | +                and self._is_port_free(preferred_port)  | 
 | 71 | +            ):  | 
 | 72 | +                self._allocated_ports.add(preferred_port)  | 
 | 73 | +                return preferred_port  | 
 | 74 | + | 
 | 75 | +            # Let the OS choose a free port, but verify it's not in our tracking structures  | 
 | 76 | +            # The OS naturally avoids ports in TIME_WAIT (without SO_REUSEADDR)  | 
 | 77 | +            for attempt in range(max_attempts):  | 
 | 78 | +                port = self._find_free_port()  | 
 | 79 | + | 
 | 80 | +                # Skip if already allocated by us or recently released  | 
 | 81 | +                # This prevents race conditions within our process  | 
 | 82 | +                if port not in self._allocated_ports and port not in self._recently_released:  | 
 | 83 | +                    self._allocated_ports.add(port)  | 
 | 84 | + | 
 | 85 | +                    # Log diagnostics if queue utilization is high  | 
 | 86 | +                    queue_count = len(self._recently_released)  | 
 | 87 | +                    if queue_count > _RECENTLY_RELEASED_PORTS_MAXLEN * 0.8:  # >80% full  | 
 | 88 | +                        log.warning(  | 
 | 89 | +                            f"Port queue utilization high: {queue_count}/{_RECENTLY_RELEASED_PORTS_MAXLEN} "  | 
 | 90 | +                            f"({queue_count / _RECENTLY_RELEASED_PORTS_MAXLEN * 100:.1f}% full). "  | 
 | 91 | +                            f"Allocated port {port}. Active allocations: {len(self._allocated_ports)}"  | 
 | 92 | +                        )  | 
 | 93 | + | 
 | 94 | +                    return port  | 
 | 95 | + | 
 | 96 | +            # Provide detailed diagnostics to understand allocation failures  | 
 | 97 | +            allocated_count = len(self._allocated_ports)  | 
 | 98 | +            queue_count = len(self._recently_released)  | 
 | 99 | +            queue_capacity = _RECENTLY_RELEASED_PORTS_MAXLEN  | 
 | 100 | +            queue_utilization = (queue_count / queue_capacity * 100) if queue_capacity > 0 else 0  | 
 | 101 | + | 
 | 102 | +            raise RuntimeError(  | 
 | 103 | +                f"Failed to allocate a free port after {max_attempts} attempts. "  | 
 | 104 | +                f"Diagnostics: allocated={allocated_count}, "  | 
 | 105 | +                f"recently_released={queue_count}/{queue_capacity} ({queue_utilization:.1f}% full). "  | 
 | 106 | +                f"If queue is near capacity, consider increasing _RECENTLY_RELEASED_PORTS_MAXLEN."  | 
 | 107 | +            )  | 
 | 108 | + | 
 | 109 | +    def release_port(self, port: int) -> None:  | 
 | 110 | +        """Release a previously allocated port.  | 
 | 111 | +
  | 
 | 112 | +        Args:  | 
 | 113 | +            port: Port number to release  | 
 | 114 | +
  | 
 | 115 | +        """  | 
 | 116 | +        with self._lock:  | 
 | 117 | +            if port in self._allocated_ports:  | 
 | 118 | +                self._allocated_ports.remove(port)  | 
 | 119 | +                # Add to the back of the queue; oldest will be evicted when queue is full  | 
 | 120 | +                self._recently_released.append(port)  | 
 | 121 | + | 
 | 122 | +    def release_all(self) -> None:  | 
 | 123 | +        """Release all allocated ports."""  | 
 | 124 | +        with self._lock:  | 
 | 125 | +            self._allocated_ports.clear()  | 
 | 126 | +            self._recently_released.clear()  | 
 | 127 | + | 
 | 128 | +    def reserve_existing_port(self, port: int) -> bool:  | 
 | 129 | +        """Reserve a port that was allocated externally.  | 
 | 130 | +
  | 
 | 131 | +        Args:  | 
 | 132 | +            port: The externally assigned port to reserve.  | 
 | 133 | +
  | 
 | 134 | +        Returns:  | 
 | 135 | +            True if the port was reserved (or already reserved), False if the port value is invalid.  | 
 | 136 | +
  | 
 | 137 | +        """  | 
 | 138 | +        if port <= 0 or port > 65535:  | 
 | 139 | +            return False  | 
 | 140 | + | 
 | 141 | +        with self._lock:  | 
 | 142 | +            if port in self._allocated_ports:  | 
 | 143 | +                return True  | 
 | 144 | + | 
 | 145 | +            # Remove from recently released queue if present (we're explicitly reserving it)  | 
 | 146 | +            if port in self._recently_released:  | 
 | 147 | +                # Create a new deque without this port  | 
 | 148 | +                self._recently_released = deque(  | 
 | 149 | +                    (p for p in self._recently_released if p != port), maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN  | 
 | 150 | +                )  | 
 | 151 | + | 
 | 152 | +            self._allocated_ports.add(port)  | 
 | 153 | +            return True  | 
 | 154 | + | 
 | 155 | +    @contextmanager  | 
 | 156 | +    def allocated_port(self, preferred_port: Optional[int] = None) -> Iterator[int]:  | 
 | 157 | +        """Context manager for automatic port cleanup.  | 
 | 158 | +
  | 
 | 159 | +        Usage:  | 
 | 160 | +            with manager.allocated_port() as port:  | 
 | 161 | +                # Use port here  | 
 | 162 | +                pass  | 
 | 163 | +            # Port automatically released  | 
 | 164 | +
  | 
 | 165 | +        Args:  | 
 | 166 | +            preferred_port: Optional preferred port number  | 
 | 167 | +
  | 
 | 168 | +        Yields:  | 
 | 169 | +            Allocated port number  | 
 | 170 | +
  | 
 | 171 | +        """  | 
 | 172 | +        port = self.allocate_port(preferred_port=preferred_port)  | 
 | 173 | +        try:  | 
 | 174 | +            yield port  | 
 | 175 | +        finally:  | 
 | 176 | +            self.release_port(port)  | 
 | 177 | + | 
 | 178 | +    @staticmethod  | 
 | 179 | +    def _find_free_port() -> int:  | 
 | 180 | +        """Find a free port using OS allocation.  | 
 | 181 | +
  | 
 | 182 | +        Returns:  | 
 | 183 | +            A port number that was free at the time of checking  | 
 | 184 | +
  | 
 | 185 | +        """  | 
 | 186 | +        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)  | 
 | 187 | +        # Don't use SO_REUSEADDR - we need to match the behavior of TCPStore  | 
 | 188 | +        # which binds without it, so ports in TIME_WAIT will be rejected  | 
 | 189 | +        s.bind(("", 0))  | 
 | 190 | +        port = s.getsockname()[1]  | 
 | 191 | +        s.close()  | 
 | 192 | +        return port  | 
 | 193 | + | 
 | 194 | +    @staticmethod  | 
 | 195 | +    def _is_port_free(port: int) -> bool:  | 
 | 196 | +        """Check if a specific port is available.  | 
 | 197 | +
  | 
 | 198 | +        Args:  | 
 | 199 | +            port: Port number to check  | 
 | 200 | +
  | 
 | 201 | +        Returns:  | 
 | 202 | +            True if the port is free, False otherwise  | 
 | 203 | +
  | 
 | 204 | +        """  | 
 | 205 | +        try:  | 
 | 206 | +            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)  | 
 | 207 | +            # Don't use SO_REUSEADDR - we need to match the behavior of TCPStore  | 
 | 208 | +            # which binds without it, so ports in TIME_WAIT will be rejected  | 
 | 209 | +            s.bind(("", port))  | 
 | 210 | +            s.close()  | 
 | 211 | +            return True  | 
 | 212 | +        except OSError:  | 
 | 213 | +            return False  | 
 | 214 | + | 
 | 215 | + | 
 | 216 | +# Global singleton instance  | 
 | 217 | +_port_manager: Optional[PortManager] = None  | 
 | 218 | +_port_manager_lock = threading.Lock()  | 
 | 219 | + | 
 | 220 | + | 
 | 221 | +def get_port_manager() -> PortManager:  | 
 | 222 | +    """Get or create the global port manager instance.  | 
 | 223 | +
  | 
 | 224 | +    Returns:  | 
 | 225 | +        The global PortManager singleton  | 
 | 226 | +
  | 
 | 227 | +    """  | 
 | 228 | +    global _port_manager  | 
 | 229 | +    if _port_manager is None:  | 
 | 230 | +        with _port_manager_lock:  | 
 | 231 | +            if _port_manager is None:  | 
 | 232 | +                _port_manager = PortManager()  | 
 | 233 | +    return _port_manager  | 
0 commit comments