diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 0e1cc944a3492..1db30fb489b47 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -27,7 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed `EADDRINUSE` errors in distributed tests with port manager and retry logic ([#21309](https://github.com/Lightning-AI/pytorch-lightning/pull/21309))
 
 
 ---
 
diff --git a/src/lightning/fabric/plugins/environments/lightning.py b/src/lightning/fabric/plugins/environments/lightning.py
index a89d840cb0812..7f83a8527089e 100644
--- a/src/lightning/fabric/plugins/environments/lightning.py
+++ b/src/lightning/fabric/plugins/environments/lightning.py
@@ -13,11 +13,11 @@
 # limitations under the License.
 import os
-import socket
 
 from typing_extensions import override
 
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
+from lightning.fabric.utilities.port_manager import get_port_manager
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 
 
@@ -104,6 +104,13 @@ def teardown(self) -> None:
         if "WORLD_SIZE" in os.environ:
             del os.environ["WORLD_SIZE"]
 
+        if self._main_port != -1:
+            get_port_manager().release_port(self._main_port)
+            self._main_port = -1
+
+        os.environ.pop("MASTER_PORT", None)
+        os.environ.pop("MASTER_ADDR", None)
+
 
 def find_free_network_port() -> int:
     """Finds a free port on localhost.
@@ -111,9 +118,24 @@ def find_free_network_port() -> int:
     It is useful in single-node training when we don't want to connect to a real main node but have to set the
     `MASTER_PORT` environment variable.
 
+    The allocated port is reserved and won't be returned by subsequent calls until it's explicitly released.
+
+    Returns:
+        A port number that is reserved and free at the time of allocation
+
     """
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    s.bind(("", 0))
-    port = s.getsockname()[1]
-    s.close()
-    return port
+    # If an external launcher already specified a MASTER_PORT (for example, torch.multiprocessing.spawn or
+    # other multiprocessing helpers), reserve it through the port manager so no other test reuses the same number.
+    if "MASTER_PORT" in os.environ:
+        master_port_str = os.environ["MASTER_PORT"]
+        try:
+            existing_port = int(master_port_str)
+        except ValueError:
+            pass
+        else:
+            port_manager = get_port_manager()
+            if port_manager.reserve_existing_port(existing_port):
+                return existing_port
+
+    port_manager = get_port_manager()
+    return port_manager.allocate_port()
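Reviewer note: the reservation semantics of the rewritten `find_free_network_port` can be exercised in isolation. A minimal sketch, assuming this branch is installed (it pokes the private `_allocated_ports` set the PR's own tests also inspect):

```python
import os

from lightning.fabric.plugins.environments.lightning import find_free_network_port
from lightning.fabric.utilities.port_manager import get_port_manager

# Case 1: no MASTER_PORT in the environment -- a fresh port is allocated and reserved.
os.environ.pop("MASTER_PORT", None)
port = find_free_network_port()
assert port in get_port_manager()._allocated_ports

# Case 2: an external launcher already exported MASTER_PORT -- the same number is
# reserved through the port manager and returned instead of a new allocation.
os.environ["MASTER_PORT"] = str(port)
assert find_free_network_port() == port

# Cleanup for the sketch only; in the test suite the conftest fixtures do this.
get_port_manager().release_port(port)
os.environ.pop("MASTER_PORT", None)
```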
diff --git a/src/lightning/fabric/utilities/port_manager.py b/src/lightning/fabric/utilities/port_manager.py
new file mode 100644
index 0000000000000..cdf19605023d5
--- /dev/null
+++ b/src/lightning/fabric/utilities/port_manager.py
@@ -0,0 +1,233 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Port allocation manager to prevent race conditions in distributed training."""
+
+import atexit
+import logging
+import socket
+import threading
+from collections import deque
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import Optional
+
+log = logging.getLogger(__name__)
+
+# Size of the recently released ports queue
+# This prevents immediate reuse of ports that were just released
+# Set to 1024 to balance memory usage vs TIME_WAIT protection
+_RECENTLY_RELEASED_PORTS_MAXLEN = 1024
+
+
+class PortManager:
+    """Thread-safe port manager to prevent EADDRINUSE errors.
+
+    This manager maintains a global registry of allocated ports to ensure that multiple concurrent tests don't try to
+    use the same port. While this doesn't completely eliminate the race condition with external processes, it prevents
+    internal collisions within the test suite.
+
+    """
+
+    def __init__(self) -> None:
+        self._lock = threading.Lock()
+        self._allocated_ports: set[int] = set()
+        # Recently released ports are kept in a queue to avoid immediate reuse
+        self._recently_released: deque[int] = deque(maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN)
+        # Register cleanup to release all ports on exit
+        atexit.register(self.release_all)
+
+    def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int = 1000) -> int:
+        """Allocate a free port, ensuring it's not already reserved.
+
+        Args:
+            preferred_port: If provided, try to allocate this specific port first
+            max_attempts: Maximum number of attempts to find a free port
+
+        Returns:
+            An allocated port number
+
+        Raises:
+            RuntimeError: If unable to find a free port after max_attempts
+
+        """
+        with self._lock:
+            # If a preferred port is specified and available, use it
+            if (
+                preferred_port is not None
+                and preferred_port not in self._allocated_ports
+                and preferred_port not in self._recently_released
+                and self._is_port_free(preferred_port)
+            ):
+                self._allocated_ports.add(preferred_port)
+                return preferred_port
+
+            # Let the OS choose a free port, but verify it's not in our tracking structures
+            # The OS naturally avoids ports in TIME_WAIT (without SO_REUSEADDR)
+            for _ in range(max_attempts):
+                port = self._find_free_port()
+
+                # Skip if already allocated by us or recently released
+                # This prevents race conditions within our process
+                if port not in self._allocated_ports and port not in self._recently_released:
+                    self._allocated_ports.add(port)
+
+                    # Log diagnostics if queue utilization is high
+                    queue_count = len(self._recently_released)
+                    if queue_count > _RECENTLY_RELEASED_PORTS_MAXLEN * 0.8:  # >80% full
+                        log.warning(
+                            f"Port queue utilization high: {queue_count}/{_RECENTLY_RELEASED_PORTS_MAXLEN} "
+                            f"({queue_count / _RECENTLY_RELEASED_PORTS_MAXLEN * 100:.1f}% full). "
+                            f"Allocated port {port}. Active allocations: {len(self._allocated_ports)}"
+                        )
+
+                    return port
+
+            # Provide detailed diagnostics to understand allocation failures
+            allocated_count = len(self._allocated_ports)
+            queue_count = len(self._recently_released)
+            queue_capacity = _RECENTLY_RELEASED_PORTS_MAXLEN
+            queue_utilization = (queue_count / queue_capacity * 100) if queue_capacity > 0 else 0
+
+            raise RuntimeError(
+                f"Failed to allocate a free port after {max_attempts} attempts. "
+                f"Diagnostics: allocated={allocated_count}, "
+                f"recently_released={queue_count}/{queue_capacity} ({queue_utilization:.1f}% full). "
+                f"If queue is near capacity, consider increasing _RECENTLY_RELEASED_PORTS_MAXLEN."
+            )
+
+    def release_port(self, port: int) -> None:
+        """Release a previously allocated port.
+
+        Args:
+            port: Port number to release
+
+        """
+        with self._lock:
+            if port in self._allocated_ports:
+                self._allocated_ports.remove(port)
+                # Add to the back of the queue; oldest will be evicted when queue is full
+                self._recently_released.append(port)
+
+    def release_all(self) -> None:
+        """Release all allocated ports."""
+        with self._lock:
+            self._allocated_ports.clear()
+            self._recently_released.clear()
+
+    def reserve_existing_port(self, port: int) -> bool:
+        """Reserve a port that was allocated externally.
+
+        Args:
+            port: The externally assigned port to reserve.
+
+        Returns:
+            True if the port was reserved (or already reserved), False if the port value is invalid.
+
+        """
+        if port <= 0 or port > 65535:
+            return False
+
+        with self._lock:
+            if port in self._allocated_ports:
+                return True
+
+            # Remove from recently released queue if present (we're explicitly reserving it)
+            if port in self._recently_released:
+                # Create a new deque without this port
+                self._recently_released = deque(
+                    (p for p in self._recently_released if p != port), maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN
+                )
+
+            self._allocated_ports.add(port)
+            return True
+
+    @contextmanager
+    def allocated_port(self, preferred_port: Optional[int] = None) -> Iterator[int]:
+        """Context manager for automatic port cleanup.
+
+        Usage:
+            with manager.allocated_port() as port:
+                # Use port here
+                pass
+            # Port automatically released
+
+        Args:
+            preferred_port: Optional preferred port number
+
+        Yields:
+            Allocated port number
+
+        """
+        port = self.allocate_port(preferred_port=preferred_port)
+        try:
+            yield port
+        finally:
+            self.release_port(port)
+
+    @staticmethod
+    def _find_free_port() -> int:
+        """Find a free port using OS allocation.
+
+        Returns:
+            A port number that was free at the time of checking
+
+        """
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        # Don't use SO_REUSEADDR - we need to match the behavior of TCPStore
+        # which binds without it, so ports in TIME_WAIT will be rejected
+        s.bind(("", 0))
+        port = s.getsockname()[1]
+        s.close()
+        return port
+
+    @staticmethod
+    def _is_port_free(port: int) -> bool:
+        """Check if a specific port is available.
+
+        Args:
+            port: Port number to check
+
+        Returns:
+            True if the port is free, False otherwise
+
+        """
+        try:
+            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            # Don't use SO_REUSEADDR - we need to match the behavior of TCPStore
+            # which binds without it, so ports in TIME_WAIT will be rejected
+            s.bind(("", port))
+            s.close()
+            return True
+        except OSError:
+            return False
+
+
+# Global singleton instance
+_port_manager: Optional[PortManager] = None
+_port_manager_lock = threading.Lock()
+
+
+def get_port_manager() -> PortManager:
+    """Get or create the global port manager instance.
+
+    Returns:
+        The global PortManager singleton
+
+    """
+    global _port_manager
+    if _port_manager is None:
+        with _port_manager_lock:
+            if _port_manager is None:
+                _port_manager = PortManager()
+    return _port_manager
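Reviewer note: the intended consumption pattern for test code is the singleton accessor plus the `allocated_port()` context manager, which keeps reservation and release paired even when the test body raises. A short sketch (the bind is only a stand-in for what TCPStore does with the port):

```python
import socket

from lightning.fabric.utilities.port_manager import get_port_manager

manager = get_port_manager()  # process-wide singleton; safe to call from any thread

# allocated_port() pairs allocate_port() with release_port(), releasing on exceptions too.
with manager.allocated_port() as port:
    # Stand-in for TCPStore: bind the reserved port to show it is actually free.
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind(("127.0.0.1", port))
    server.close()

assert port not in manager._allocated_ports  # released on exit
```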
diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py
index 9d4a0b9462f2e..a06bb0eacdbb4 100644
--- a/tests/tests_fabric/conftest.py
+++ b/tests/tests_fabric/conftest.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
 import os
 import sys
 import threading
@@ -77,9 +78,25 @@ def restore_env_variables():
 @pytest.fixture(autouse=True)
 def teardown_process_group():
     """Ensures that the distributed process group gets closed before the next test runs."""
+    import os
+
+    from lightning.fabric.utilities.port_manager import get_port_manager
+
     yield
+
+    # Clean up distributed connection
     _destroy_dist_connection()
+    manager = get_port_manager()
+
+    # If a process group created or updated MASTER_PORT during the test, reserve it and then clear it
+    if "MASTER_PORT" in os.environ:
+        with contextlib.suppress(ValueError):
+            port = int(os.environ["MASTER_PORT"])
+            manager.reserve_existing_port(port)
+            manager.release_port(port)
+        os.environ.pop("MASTER_PORT", None)
+
 
 
 @pytest.fixture(autouse=True)
 def thread_police_duuu_daaa_duuu_daaa():
@@ -204,6 +221,67 @@ def leave_no_artifacts_behind():
     assert not difference, f"Test left artifacts behind: {difference}"
 
 
+def pytest_runtest_makereport(item: pytest.Item, call: pytest.CallInfo) -> None:
+    """Retry tests that fail with EADDRINUSE errors.
+
+    This handles the race condition where a port might enter TIME_WAIT between our port allocation check and actual
+    binding by TCPStore.
+
+    """
+    if call.excinfo is not None and call.when == "call":
+        exception_msg = str(call.excinfo.value)
+        exception_type = str(type(call.excinfo.value).__name__)
+        # Check if this is an EADDRINUSE error from distributed training
+        # Catch both direct EADDRINUSE errors and DistNetworkError which wraps them
+        if (
+            "EADDRINUSE" in exception_msg
+            or "address already in use" in exception_msg.lower()
+            or "DistNetworkError" in exception_type
+        ):
+            # Get the retry count from the test node
+            retry_count = getattr(item, "_port_retry_count", 0)
+            max_retries = 3
+
+            if retry_count < max_retries:
+                # Increment retry counter
+                item._port_retry_count = retry_count + 1
+
+                # Log the retry
+                if hasattr(item.config, "get_terminal_writer"):
+                    writer = item.config.get_terminal_writer()
+                    writer.write(
+                        f"\n[Port conflict detected] Retrying test {item.name} "
+                        f"(attempt {retry_count + 2}/{max_retries + 1})...\n",
+                        yellow=True,
+                    )
+
+                # Clear the port manager's state to get fresh ports
+                from lightning.fabric.utilities.port_manager import get_port_manager
+
+                manager = get_port_manager()
+                manager.release_all()
+
+                # Clear MASTER_PORT so cluster environment allocates a fresh port on retry
+                import os
+
+                os.environ.pop("MASTER_PORT", None)
+
+                # Re-run the test by raising Rerun exception
+                # Note: This requires pytest-rerunfailures plugin
+                import time
+
+                time.sleep(1.0)  # Wait for OS to release ports from TIME_WAIT state
+
+                # If pytest-rerunfailures is available, use it
+                try:
+                    from pytest_rerunfailures import Rerun
+
+                    raise Rerun(f"Port conflict (EADDRINUSE), retry {retry_count + 1}/{max_retries}")
+                except ImportError:
+                    # Plugin not available, just let the test fail
+                    pass
+
+
 def pytest_collection_modifyitems(items: list[pytest.Function], config: pytest.Config) -> None:
     """An adaptation of `tests/tests_pytorch/conftest.py::pytest_collection_modifyitems`"""
     initial_size = len(items)
diff --git a/tests/tests_fabric/plugins/environments/test_lightning.py b/tests/tests_fabric/plugins/environments/test_lightning.py
index cc0179af4c3f5..aa225475289f8 100644
--- a/tests/tests_fabric/plugins/environments/test_lightning.py
+++ b/tests/tests_fabric/plugins/environments/test_lightning.py
@@ -17,6 +17,7 @@
 import pytest
 
 from lightning.fabric.plugins.environments import LightningEnvironment
+from 
lightning.fabric.utilities.port_manager import get_port_manager @mock.patch.dict(os.environ, {}, clear=True) @@ -84,5 +85,18 @@ def test_teardown(): assert "WORLD_SIZE" not in os.environ +@mock.patch.dict(os.environ, {}, clear=True) +def test_teardown_releases_port_and_env(): + env = LightningEnvironment() + port = env.main_port + assert port in get_port_manager()._allocated_ports + + env.teardown() + + assert port not in get_port_manager()._allocated_ports + assert "MASTER_PORT" not in os.environ + assert "MASTER_ADDR" not in os.environ + + def test_detect(): assert LightningEnvironment.detect() diff --git a/tests/tests_fabric/utilities/test_port_manager.py b/tests/tests_fabric/utilities/test_port_manager.py new file mode 100644 index 0000000000000..8ea820baa9a42 --- /dev/null +++ b/tests/tests_fabric/utilities/test_port_manager.py @@ -0,0 +1,776 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the PortManager utility and port allocation integration.""" + +import os +import socket +import threading +from collections import Counter + +import pytest + +from lightning.fabric.plugins.environments.lightning import find_free_network_port +from lightning.fabric.utilities.port_manager import PortManager, get_port_manager + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def with_master_port(): + """Fixture that sets MASTER_PORT before test runs, for conftest coverage.""" + port = find_free_network_port() + previous_value = os.environ.get("MASTER_PORT") + os.environ["MASTER_PORT"] = str(port) + try: + yield port + finally: + if previous_value is None: + os.environ.pop("MASTER_PORT", None) + else: + os.environ["MASTER_PORT"] = previous_value + + +@pytest.fixture +def with_invalid_master_port(): + """Fixture that sets invalid MASTER_PORT to test error handling.""" + previous_value = os.environ.get("MASTER_PORT") + os.environ["MASTER_PORT"] = "not_a_valid_port_number" + try: + yield + finally: + if previous_value is None: + os.environ.pop("MASTER_PORT", None) + else: + os.environ["MASTER_PORT"] = previous_value + + +# ============================================================================= +# Unit Tests for PortManager +# ============================================================================= + + +def test_port_manager_allocates_unique_ports(): + """Test that PortManager allocates unique ports.""" + manager = PortManager() + + # Allocate multiple ports + ports = [manager.allocate_port() for _ in range(10)] + + # All ports should be unique + assert len(ports) == len(set(ports)), f"Duplicate ports found: {ports}" + + # All ports should be valid (>= 1024) + assert all(p >= 1024 for p in ports), "Some ports are in reserved range" + + +def test_port_manager_release_port(): + """Test that released ports are removed from the allocated set.""" + manager = PortManager() + + # Allocate a port + port = 
manager.allocate_port()
+    assert port in manager._allocated_ports
+
+    # Release the port
+    manager.release_port(port)
+    assert port not in manager._allocated_ports
+
+
+def test_port_manager_release_all():
+    """Test that release_all clears all allocated ports."""
+    manager = PortManager()
+
+    # Allocate multiple ports
+    [manager.allocate_port() for _ in range(5)]
+    assert len(manager._allocated_ports) == 5
+
+    # Release all
+    manager.release_all()
+    assert len(manager._allocated_ports) == 0
+
+
+def test_port_manager_release_nonexistent_port():
+    """Test that releasing a non-existent port doesn't cause errors."""
+    manager = PortManager()
+
+    # Try to release a port that was never allocated
+    manager.release_port(12345)  # Should not raise an error
+
+    # Verify nothing broke
+    port = manager.allocate_port()
+    assert port >= 1024
+
+
+def test_port_manager_thread_safety():
+    """Test that PortManager is thread-safe under concurrent access."""
+    manager = PortManager()
+    ports = []
+    lock = threading.Lock()
+
+    def allocate_ports():
+        """Allocate multiple ports from different threads."""
+        for _ in range(10):
+            port = manager.allocate_port()
+            with lock:
+                ports.append(port)
+
+    # Create multiple threads that allocate ports concurrently
+    threads = [threading.Thread(target=allocate_ports) for _ in range(10)]
+
+    # Start all threads
+    for thread in threads:
+        thread.start()
+
+    # Wait for all threads to complete
+    for thread in threads:
+        thread.join()
+
+    # Verify we got 100 unique ports (10 threads × 10 ports each)
+    assert len(ports) == 100, f"Expected 100 ports, got {len(ports)}"
+    assert len(set(ports)) == 100, f"Expected 100 unique ports, got {len(set(ports))}"
+
+    # Check for any duplicates
+    counts = Counter(ports)
+    duplicates = {port: count for port, count in counts.items() if count > 1}
+    assert not duplicates, f"Found duplicate ports: {duplicates}"
+
+
+def test_port_manager_preferred_port():
+    """Test that PortManager can allocate a preferred port if available."""
+    manager = PortManager()
+
+    # Try to find a free port first
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    s.bind(("", 0))
+    free_port = s.getsockname()[1]
+    s.close()
+
+    # Allocate the preferred port
+    allocated = manager.allocate_port(preferred_port=free_port)
+    assert allocated == free_port
+
+    # Try to allocate the same preferred port again (should get a different one)
+    allocated2 = manager.allocate_port(preferred_port=free_port)
+    assert allocated2 != free_port
+
+
+def test_port_manager_allocation_under_load():
+    """Test that allocation still succeeds when many ports are already reserved."""
+    manager = PortManager()
+
+    # Pre-allocate a large number of ports to make it harder to find a free one
+    # (the exhaustion failure path itself is exercised in test_port_manager_allocation_runtime_error)
+    allocated_ports = [manager.allocate_port() for _ in range(50)]
+
+    # Test that it can still allocate with enough attempts
+    port = manager.allocate_port(max_attempts=100)
+    assert port >= 1024
+
+    # Clean up
+    for p in allocated_ports:
+        manager.release_port(p)
+    manager.release_port(port)
+
+
+def test_port_manager_prevents_reallocation():
+    """Test that a port won't be allocated twice until released."""
+    manager = PortManager()
+
+    # Allocate a port
+    port1 = manager.allocate_port()
+
+    # Allocate many more ports - none should match port1
+    more_ports = [manager.allocate_port() for _ in range(50)]
+
+    # port1 should not appear in more_ports
+    assert port1 not in more_ports, f"Port {port1} was reallocated before release"
+
+    # After releasing 
port1, we should eventually be able to get it again + # (though not guaranteed due to OS port allocation) + manager.release_port(port1) + assert port1 not in manager._allocated_ports + + # Clean up + for port in more_ports: + manager.release_port(port) + + +def test_get_port_manager_singleton(): + """Test that get_port_manager returns the same instance.""" + manager1 = get_port_manager() + manager2 = get_port_manager() + + # Should be the same instance + assert manager1 is manager2 + + # Allocating from one should be visible in the other + port = manager1.allocate_port() + assert port in manager2._allocated_ports + + # Clean up + manager1.release_port(port) + + +def test_get_port_manager_thread_safe_singleton(): + """Test that get_port_manager creates singleton safely across threads.""" + managers = [] + lock = threading.Lock() + + def get_manager(): + manager = get_port_manager() + with lock: + managers.append(manager) + + # Create multiple threads that get the port manager + threads = [threading.Thread(target=get_manager) for _ in range(20)] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # All should be the same instance + assert len(managers) == 20 + assert all(m is managers[0] for m in managers), "get_port_manager returned different instances" + + +def test_port_manager_is_port_free(): + """Test the _is_port_free helper method.""" + manager = PortManager() + + # Find a free port using OS + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + free_port = s.getsockname()[1] + + # Port should be reported as not free while socket is open + assert not manager._is_port_free(free_port) + + # Close the socket + s.close() + + # Now port should be free (though there's still a small race condition) + # We'll skip this check as it's unreliable due to OS behavior + + +def test_port_manager_find_free_port(): + """Test the _find_free_port helper method.""" + manager = PortManager() + + # Should return a valid port + port = manager._find_free_port() + assert isinstance(port, int) + assert port >= 1024 + assert port <= 65535 + + +def test_port_manager_concurrent_allocation_and_release(): + """Test concurrent allocation and release operations.""" + manager = PortManager() + ports = [] + lock = threading.Lock() + active_ports: set[int] = set() + + def allocate_and_release(): + for _ in range(5): + # Allocate a port + port = manager.allocate_port() + with lock: + assert port not in active_ports, "Port allocated concurrently" + active_ports.add(port) + ports.append(port) + + # Release it immediately + manager.release_port(port) + with lock: + active_ports.remove(port) + + # Run multiple threads + threads = [threading.Thread(target=allocate_and_release) for _ in range(10)] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # Should have allocated 50 ports total (10 threads × 5 ports) + assert len(ports) == 50 + + # After all releases, manager should have no ports allocated + assert len(manager._allocated_ports) == 0 + + +def test_port_manager_atexit_cleanup(): + """Test that PortManager registers atexit cleanup.""" + + # Create a new manager + manager = PortManager() + + # The manager should have registered release_all with atexit + # We can't easily test atexit directly, but we can verify the method exists + assert callable(manager.release_all) + + # Verify release_all works + manager.allocate_port() + manager.allocate_port() + assert len(manager._allocated_ports) == 2 + + manager.release_all() + assert 
len(manager._allocated_ports) == 0 + + +def test_port_manager_reserve_existing_port_free(): + """reserve_existing_port should succeed for free ports and track them.""" + manager = PortManager() + + port = manager._find_free_port() + assert manager.reserve_existing_port(port) + assert port in manager._allocated_ports + + # Second call should succeed but not duplicate + assert manager.reserve_existing_port(port) + assert len(manager._allocated_ports) == 1 + + +def test_port_manager_reserve_existing_port_invalid_value(): + """reserve_existing_port should reject invalid port numbers.""" + manager = PortManager() + + assert not manager.reserve_existing_port(0) + assert not manager.reserve_existing_port(-1) + assert not manager.reserve_existing_port(70000) + + +def test_port_manager_reserve_existing_port_after_release(): + """Ports released from sockets should become reservable.""" + manager = PortManager() + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", 0)) + reusable_port = s.getsockname()[1] + s.close() + + assert manager.reserve_existing_port(reusable_port) + assert reusable_port in manager._allocated_ports + + +def test_port_manager_context_manager(): + """Test that context manager automatically releases ports.""" + manager = PortManager() + + # Use context manager + with manager.allocated_port() as port: + # Port should be allocated + assert port in manager._allocated_ports + assert isinstance(port, int) + assert port >= 1024 + + # After context, port should be released + assert port not in manager._allocated_ports + + +def test_port_manager_context_manager_exception(): + """Test that context manager releases port even on exception.""" + manager = PortManager() + + try: + with manager.allocated_port() as port: + allocated_port = port + # Port should be allocated + assert port in manager._allocated_ports + # Raise exception + raise ValueError("Test exception") + except ValueError: + pass + + # Port should still be released despite exception + assert allocated_port not in manager._allocated_ports + + +def test_port_manager_context_manager_nested(): + """Test that nested context managers work correctly.""" + manager = PortManager() + + with manager.allocated_port() as port1: + assert port1 in manager._allocated_ports + + with manager.allocated_port() as port2: + # Both ports should be allocated + assert port1 in manager._allocated_ports + assert port2 in manager._allocated_ports + # Ports should be different + assert port1 != port2 + + # port2 should be released, port1 still allocated + assert port1 in manager._allocated_ports + assert port2 not in manager._allocated_ports + + # Both ports should now be released + assert port1 not in manager._allocated_ports + assert port2 not in manager._allocated_ports + + +# ============================================================================= +# Integration Tests for find_free_network_port() +# ============================================================================= + + +def test_find_free_network_port_uses_port_manager(): + """Test that find_free_network_port uses the PortManager.""" + manager = get_port_manager() + + # Clear any previously allocated ports + initial_count = len(manager._allocated_ports) + + # Allocate a port using the function + port = find_free_network_port() + + # The port should be in the manager's allocated set + assert port in manager._allocated_ports + assert len(manager._allocated_ports) == initial_count + 1 + + # Clean up + 
manager.release_port(port)
+
+
+def test_find_free_network_port_returns_unique_ports():
+    """Test that multiple calls return unique ports."""
+    manager = get_port_manager()
+
+    # Allocate multiple ports
+    ports = [find_free_network_port() for _ in range(10)]
+
+    # All should be unique
+    assert len(ports) == len(set(ports)), f"Duplicate ports: {ports}"
+
+    # All should be tracked by the manager
+    for port in ports:
+        assert port in manager._allocated_ports
+
+    # Clean up
+    for port in ports:
+        manager.release_port(port)
+
+
+def test_find_free_network_port_thread_safety():
+    """Test that find_free_network_port is thread-safe."""
+    ports = []
+    lock = threading.Lock()
+
+    def allocate():
+        for _ in range(5):
+            port = find_free_network_port()
+            with lock:
+                ports.append(port)
+
+    # Run 10 threads, each allocating 5 ports
+    threads = [threading.Thread(target=allocate) for _ in range(10)]
+
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join()
+
+    # Should have 50 unique ports
+    assert len(ports) == 50
+    assert len(set(ports)) == 50, "Duplicate ports allocated across threads"
+
+    # Clean up
+    manager = get_port_manager()
+    for port in ports:
+        manager.release_port(port)
+
+
+def test_port_allocation_simulates_distributed_test_lifecycle():
+    """Simulate the lifecycle of a distributed test with port allocation and release."""
+    manager = get_port_manager()
+    initial_count = len(manager._allocated_ports)
+
+    # Simulate test setup: allocate a port
+    port = find_free_network_port()
+    os.environ["MASTER_PORT"] = str(port)
+
+    # Verify port is allocated
+    assert port in manager._allocated_ports
+
+    # Simulate test teardown: release the port
+    if "MASTER_PORT" in os.environ:
+        port_to_release = int(os.environ["MASTER_PORT"])
+        manager.release_port(port_to_release)
+        del os.environ["MASTER_PORT"]
+
+    # Verify port is released
+    assert port not in manager._allocated_ports
+    assert len(manager._allocated_ports) == initial_count
+
+
+def test_conftest_cleanup_with_master_port_set(with_master_port):
+    """Test conftest cleanup when MASTER_PORT is set before test starts.
+
+    This test uses a fixture to set MASTER_PORT before the test runs, allowing the conftest teardown_process_group
+    fixture to capture and clean it up. This ensures the conftest cleanup code is covered.
+
+    """
+    manager = get_port_manager()
+    port = with_master_port  # Port was set by fixture
+
+    # Verify port is allocated
+    assert port in manager._allocated_ports
+    assert os.environ.get("MASTER_PORT") == str(port)
+
+    # Leave MASTER_PORT set - conftest teardown will clean it up
+    # After this test, teardown_process_group will:
+    # 1. Detect MASTER_PORT in os.environ
+    # 2. Reserve and release the port through the port manager (quarantining it)
+    # 3. Remove MASTER_PORT from os.environ
+
+
+def test_conftest_handles_invalid_master_port(with_invalid_master_port):
+    """Test conftest handles invalid MASTER_PORT gracefully.
+
+    This exercises the contextlib.suppress(ValueError) path in the conftest teardown_process_group fixture.
+ + """ + # Fixture set MASTER_PORT to "not_a_valid_port_number" + # The conftest will try to parse it: int(os.environ["MASTER_PORT"]) + # This will raise ValueError, which should be caught by contextlib.suppress + + # Verify the invalid value is set + assert os.environ.get("MASTER_PORT") == "not_a_valid_port_number" + + # This test just needs to complete without crashing + # The conftest teardown will handle the ValueError gracefully + + +def test_multiple_tests_can_reuse_ports_after_release(): + """Test that ports can be reused after being released.""" + manager = get_port_manager() + + # First "test" allocates a port + port1 = find_free_network_port() + assert port1 in manager._allocated_ports + + # First "test" completes and releases the port + manager.release_port(port1) + assert port1 not in manager._allocated_ports + + # Second "test" allocates ports (may or may not get the same port) + port2 = find_free_network_port() + assert port2 in manager._allocated_ports + + # Ports should be valid regardless + assert port1 >= 1024 + assert port2 >= 1024 + + # Clean up + manager.release_port(port2) + + +def test_concurrent_tests_dont_get_same_port(): + """Test that concurrent tests never receive the same port.""" + manager = get_port_manager() + ports_per_thread = [] + lock = threading.Lock() + + def simulate_test(): + """Simulate a test that allocates a port, uses it, then releases it.""" + my_ports = [] + + # Allocate port for this "test" + port = find_free_network_port() + my_ports.append(port) + + # Simulate some work + import time + + time.sleep(0.001) + + # Release port after "test" completes + manager.release_port(port) + + with lock: + ports_per_thread.append(my_ports) + + # Run 20 concurrent "tests" + threads = [threading.Thread(target=simulate_test) for _ in range(20)] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # Collect all ports that were allocated + all_ports = [port for thread_ports in ports_per_thread for port in thread_ports] + + # All ports should have been unique at the time of allocation + assert len(all_ports) == 20 + assert len(set(all_ports)) == 20, "Some concurrent tests got the same port!" 
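Reviewer aside: the quarantine behavior the tests above rely on is easiest to see in a few lines. An illustrative sketch (like the tests, it touches the private `_recently_released` deque):

```python
from lightning.fabric.utilities.port_manager import PortManager

manager = PortManager()

port = manager.allocate_port()
manager.release_port(port)

# A released port is quarantined in the recently-released queue rather than returned
# to the pool, so the next allocation cannot hand the same number to another test.
assert port in manager._recently_released
fresh = manager.allocate_port()
assert fresh != port
manager.release_port(fresh)
```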
+
+
+def test_port_manager_survives_multiple_test_sessions():
+    """Test that the port manager maintains state across multiple test sessions."""
+    manager = get_port_manager()
+
+    # Session 1: Allocate some ports
+    session1_ports = [find_free_network_port() for _ in range(3)]
+
+    # Session 2: Allocate more ports (should not overlap with session 1)
+    session2_ports = [find_free_network_port() for _ in range(3)]
+
+    # No overlap between sessions while both are active
+    assert not set(session1_ports) & set(session2_ports)
+
+    # Release session 1 ports
+    for port in session1_ports:
+        manager.release_port(port)
+
+    # Session 3: Can allocate more ports
+    session3_ports = [find_free_network_port() for _ in range(3)]
+
+    # Session 3 shouldn't overlap with active session 2
+    assert not set(session2_ports) & set(session3_ports)
+
+    # Clean up
+    for port in session2_ports + session3_ports:
+        manager.release_port(port)
+
+
+def test_port_manager_allocation_runtime_error():
+    """Test that allocation fails gracefully when max_attempts is exhausted."""
+    manager = PortManager()
+
+    # Allocate a port first
+    allocated_port = manager.allocate_port()
+
+    # Mock _find_free_port to always return the already-allocated port
+    # This will cause all allocation attempts to be skipped (port in allocated_ports)
+    original_find = manager._find_free_port
+
+    def always_return_allocated():
+        return allocated_port
+
+    manager._find_free_port = always_return_allocated
+
+    # This should raise RuntimeError after max_attempts
+    with pytest.raises(RuntimeError, match="Failed to allocate a free port after .* attempts"):
+        manager.allocate_port(max_attempts=5)
+
+    # Restore original method and clean up
+    manager._find_free_port = original_find
+    manager.release_port(allocated_port)
+
+
+def test_find_free_network_port_respects_existing_master_port(with_master_port):
+    """find_free_network_port should reuse externally provided MASTER_PORT."""
+    manager = get_port_manager()
+    port = with_master_port
+
+    returned_port = find_free_network_port()
+    assert returned_port == port
+    assert port in manager._allocated_ports
+
+
+def test_find_free_network_port_handles_invalid_master_port(with_invalid_master_port):
+    """Invalid MASTER_PORT values should fall back to allocating a fresh port."""
+    manager = get_port_manager()
+
+    returned_port = find_free_network_port()
+    assert isinstance(returned_port, int)
+    assert returned_port in manager._allocated_ports
+    assert 1024 <= returned_port <= 65535
+
+
+def test_port_manager_recently_released_prevents_immediate_reuse():
+    """Test that released ports enter recently_released queue and can't be immediately reallocated."""
+    manager = PortManager()
+
+    # Allocate and release a port
+    port = manager.allocate_port()
+    manager.release_port(port)
+
+    # Port should be in recently_released queue
+    assert port in manager._recently_released
+    assert port not in manager._allocated_ports
+
+    # Try to allocate again - should get a different port
+    new_port = manager.allocate_port()
+    assert new_port != port
+    assert new_port in manager._allocated_ports
+
+    manager.release_port(new_port)
+
+
+def test_port_manager_recently_released_queue_cycles():
+    """Test that recently_released queue cycles after maxlen allocations."""
+    from lightning.fabric.utilities.port_manager import _RECENTLY_RELEASED_PORTS_MAXLEN
+
+    manager = PortManager()
+
+    # Allocate and release a port
+    first_port = manager.allocate_port()
+    manager.release_port(first_port)
+
+    # Port should be in recently_released queue
+    assert 
first_port in manager._recently_released + + # Allocate and release many ports to fill the queue beyond maxlen + for _ in range(_RECENTLY_RELEASED_PORTS_MAXLEN + 10): + port = manager.allocate_port() + manager.release_port(port) + + # First port should have been evicted from the queue (oldest entry) + assert first_port not in manager._recently_released + + +def test_port_manager_reserve_clears_recently_released(): + """Test that reserve_existing_port clears recently_released for that port.""" + manager = PortManager() + + # Allocate and release a port + port = manager.allocate_port() + manager.release_port(port) + + # Port should be in recently_released + assert port in manager._recently_released + + # Reserve the port - should clear from recently_released and mark as allocated + assert manager.reserve_existing_port(port) + assert port not in manager._recently_released + assert port in manager._allocated_ports + + manager.release_port(port) + + +def test_port_manager_high_queue_utilization_warning(caplog): + """Test that warning is logged when queue utilization exceeds 80%.""" + import logging + + manager = PortManager() + + # Fill queue to >80% (821/1024 = 80.2%) + for _ in range(821): + port = manager.allocate_port() + manager.release_port(port) + + # Next allocation should trigger warning + with caplog.at_level(logging.WARNING): + port = manager.allocate_port() + manager.release_port(port) + + # Verify warning was logged + assert any("Port queue utilization high" in record.message for record in caplog.records) + assert any("80." in record.message for record in caplog.records) # Should show 80.x% diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index 878298c6bfd94..bb48ca8717e45 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import os import signal import sys @@ -127,9 +128,25 @@ def restore_signal_handlers(): @pytest.fixture(autouse=True) def teardown_process_group(): """Ensures that the distributed process group gets closed before the next test runs.""" + import os + + from lightning.fabric.utilities.port_manager import get_port_manager + yield + + # Clean up distributed connection _destroy_dist_connection() + manager = get_port_manager() + + # If a process group created or updated MASTER_PORT during the test, reserve it and then clear it + if "MASTER_PORT" in os.environ: + with contextlib.suppress(ValueError): + port = int(os.environ["MASTER_PORT"]) + manager.reserve_existing_port(port) + manager.release_port(port) + os.environ.pop("MASTER_PORT", None) + @pytest.fixture(autouse=True) def reset_deterministic_algorithm(): @@ -332,6 +349,67 @@ def leave_no_artifacts_behind(): assert not difference, f"Test left artifacts behind: {difference}" +def pytest_runtest_makereport(item: pytest.Item, call: pytest.CallInfo) -> None: + """Retry tests that fail with EADDRINUSE errors. + + This handles the race condition where a port might enter TIME_WAIT between our port allocation check and actual + binding by TCPStore. 
+ + """ + if call.excinfo is not None and call.when == "call": + exception_msg = str(call.excinfo.value) + exception_type = str(type(call.excinfo.value).__name__) + # Check if this is an EADDRINUSE error from distributed training + # Catch both direct EADDRINUSE errors and DistNetworkError which wraps them + if ( + "EADDRINUSE" in exception_msg + or "address already in use" in exception_msg.lower() + or "DistNetworkError" in exception_type + ): + # Get the retry count from the test node + retry_count = getattr(item, "_port_retry_count", 0) + max_retries = 3 + + if retry_count < max_retries: + # Increment retry counter + item._port_retry_count = retry_count + 1 + + # Log the retry + if hasattr(item.config, "get_terminal_writer"): + writer = item.config.get_terminal_writer() + writer.write( + f"\n[Port conflict detected] Retrying test {item.name} " + f"(attempt {retry_count + 2}/{max_retries + 1})...\n", + yellow=True, + ) + + # Clear the port manager's state to get fresh ports + from lightning.fabric.utilities.port_manager import get_port_manager + + manager = get_port_manager() + manager.release_all() + + # Clear MASTER_PORT so cluster environment allocates a fresh port on retry + import os + + os.environ.pop("MASTER_PORT", None) + + # Re-run the test by raising Rerun exception + # Note: This requires pytest-rerunfailures plugin + import time + + time.sleep(1.0) # Wait for OS to release ports from TIME_WAIT state + + # If pytest-rerunfailures is available, use it + try: + from pytest_rerunfailures import Rerun + + raise Rerun(f"Port conflict (EADDRINUSE), retry {retry_count + 1}/{max_retries}") + except ImportError: + # Plugin not available, just let the test fail + pass + + def pytest_collection_modifyitems(items: list[pytest.Function], config: pytest.Config) -> None: initial_size = len(items) conditions = []
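Closing reviewer note: the error classification both conftest hooks share can be sanity-checked standalone. A sketch of the predicate, mirroring the three conditions in `pytest_runtest_makereport` (`is_port_conflict` is an illustrative name, not a helper this PR adds):

```python
def is_port_conflict(exc: BaseException) -> bool:
    # Mirrors the hook's conditions: raw EADDRINUSE errors, generic
    # "address already in use" messages, and torch's DistNetworkError wrapper.
    msg = str(exc)
    return (
        "EADDRINUSE" in msg
        or "address already in use" in msg.lower()
        or type(exc).__name__ == "DistNetworkError"
    )


assert is_port_conflict(OSError(98, "Address already in use"))
assert not is_port_conflict(ValueError("unrelated failure"))
```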