diff --git a/.lightning/workflows/pytorch.yml b/.lightning/workflows/pytorch.yml index 39ddaf708d0c7..6415e5a929f5b 100644 --- a/.lightning/workflows/pytorch.yml +++ b/.lightning/workflows/pytorch.yml @@ -27,7 +27,7 @@ env: DEBIAN_FRONTEND: "noninteractive" CUDA_TOOLKIT_ROOT_DIR: "/usr/local/cuda" MKL_THREADING_LAYER: "GNU" - CUDA_LAUNCH_BLOCKING: "1" + CUDA_LAUNCH_BLOCKING: "0" NCCL_DEBUG: "INFO" TORCHDYNAMO_VERBOSE: "1" FREEZE_REQUIREMENTS: "1" diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 1db30fb489b47..0e1cc944a3492 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -27,7 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- Fixed `EADDRINUSE` errors in distributed tests with port manager and retry logic ([#21309](https://github.com/Lightning-AI/pytorch-lightning/pull/21309)) +- --- diff --git a/src/lightning/fabric/plugins/environments/lightning.py b/src/lightning/fabric/plugins/environments/lightning.py index 7f83a8527089e..e4e299ec81b4e 100644 --- a/src/lightning/fabric/plugins/environments/lightning.py +++ b/src/lightning/fabric/plugins/environments/lightning.py @@ -17,7 +17,6 @@ from typing_extensions import override from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment -from lightning.fabric.utilities.port_manager import get_port_manager from lightning.fabric.utilities.rank_zero import rank_zero_only @@ -104,38 +103,17 @@ def teardown(self) -> None: if "WORLD_SIZE" in os.environ: del os.environ["WORLD_SIZE"] - if self._main_port != -1: - get_port_manager().release_port(self._main_port) - self._main_port = -1 - - os.environ.pop("MASTER_PORT", None) - os.environ.pop("MASTER_ADDR", None) - def find_free_network_port() -> int: - """Finds a free port on localhost. + """Finds a free port on localhost with cross-process coordination. It is useful in single-node training when we don't want to connect to a real main node but have to set the `MASTER_PORT` environment variable. - The allocated port is reserved and won't be returned by subsequent calls until it's explicitly released. - - Returns: - A port number that is reserved and free at the time of allocation + Uses OS port allocation with retry attempts to reduce port conflicts between parallel pytest workers. + A small race window between allocation and binding remains and is accepted. """ - # If an external launcher already specified a MASTER_PORT (for example, torch.distributed.spawn or - # multiprocessing helpers), reserve it through the port manager so no other test reuses the same number. - if "MASTER_PORT" in os.environ: - master_port_str = os.environ["MASTER_PORT"] - try: - existing_port = int(master_port_str) - except ValueError: - pass - else: - port_manager = get_port_manager() - if port_manager.reserve_existing_port(existing_port): - return existing_port - - port_manager = get_port_manager() - return port_manager.allocate_port() + from lightning.fabric.utilities.port_manager import allocate_port_with_lock + + return allocate_port_with_lock() diff --git a/src/lightning/fabric/utilities/port_manager.py b/src/lightning/fabric/utilities/port_manager.py index cdf19605023d5..280822fb26929 100644 --- a/src/lightning/fabric/utilities/port_manager.py +++ b/src/lightning/fabric/utilities/port_manager.py @@ -11,223 +11,78 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
-"""Port allocation manager to prevent race conditions in distributed training.""" +"""Port allocation with retry logic for distributed training.""" -import atexit -import logging import socket -import threading -from collections import deque -from collections.abc import Iterator -from contextlib import contextmanager from typing import Optional -log = logging.getLogger(__name__) -# Size of the recently released ports queue -# This prevents immediate reuse of ports that were just released -# Set to 1024 to balance memory usage vs TIME_WAIT protection -_RECENTLY_RELEASED_PORTS_MAXLEN = 1024 - - -class PortManager: - """Thread-safe port manager to prevent EADDRINUSE errors. - - This manager maintains a global registry of allocated ports to ensure that multiple concurrent tests don't try to - use the same port. While this doesn't completely eliminate the race condition with external processes, it prevents - internal collisions within the test suite. +def _find_free_port() -> int: + """Find a free port using OS allocation.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] - """ - def __init__(self) -> None: - self._lock = threading.Lock() - self._allocated_ports: set[int] = set() - # Recently released ports are kept in a queue to avoid immediate reuse - self._recently_released: deque[int] = deque(maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN) - # Register cleanup to release all ports on exit - atexit.register(self.release_all) - - def allocate_port(self, preferred_port: Optional[int] = None, max_attempts: int = 1000) -> int: - """Allocate a free port, ensuring it's not already reserved. - - Args: - preferred_port: If provided, try to allocate this specific port first - max_attempts: Maximum number of attempts to find a free port - - Returns: - An allocated port number - - Raises: - RuntimeError: If unable to find a free port after max_attempts - - """ - with self._lock: - # If a preferred port is specified and available, use it - if ( - preferred_port is not None - and preferred_port not in self._allocated_ports - and preferred_port not in self._recently_released - and self._is_port_free(preferred_port) - ): - self._allocated_ports.add(preferred_port) - return preferred_port - - # Let the OS choose a free port, but verify it's not in our tracking structures - # The OS naturally avoids ports in TIME_WAIT (without SO_REUSEADDR) - for attempt in range(max_attempts): - port = self._find_free_port() - - # Skip if already allocated by us or recently released - # This prevents race conditions within our process - if port not in self._allocated_ports and port not in self._recently_released: - self._allocated_ports.add(port) - - # Log diagnostics if queue utilization is high - queue_count = len(self._recently_released) - if queue_count > _RECENTLY_RELEASED_PORTS_MAXLEN * 0.8: # >80% full - log.warning( - f"Port queue utilization high: {queue_count}/{_RECENTLY_RELEASED_PORTS_MAXLEN} " - f"({queue_count / _RECENTLY_RELEASED_PORTS_MAXLEN * 100:.1f}% full). " - f"Allocated port {port}. Active allocations: {len(self._allocated_ports)}" - ) - - return port - - # Provide detailed diagnostics to understand allocation failures - allocated_count = len(self._allocated_ports) - queue_count = len(self._recently_released) - queue_capacity = _RECENTLY_RELEASED_PORTS_MAXLEN - queue_utilization = (queue_count / queue_capacity * 100) if queue_capacity > 0 else 0 - - raise RuntimeError( - f"Failed to allocate a free port after {max_attempts} attempts. 
" - f"Diagnostics: allocated={allocated_count}, " - f"recently_released={queue_count}/{queue_capacity} ({queue_utilization:.1f}% full). " - f"If queue is near capacity, consider increasing _RECENTLY_RELEASED_PORTS_MAXLEN." - ) - - def release_port(self, port: int) -> None: - """Release a previously allocated port. - - Args: - port: Port number to release - - """ - with self._lock: - if port in self._allocated_ports: - self._allocated_ports.remove(port) - # Add to the back of the queue; oldest will be evicted when queue is full - self._recently_released.append(port) - - def release_all(self) -> None: - """Release all allocated ports.""" - with self._lock: - self._allocated_ports.clear() - self._recently_released.clear() - - def reserve_existing_port(self, port: int) -> bool: - """Reserve a port that was allocated externally. - - Args: - port: The externally assigned port to reserve. - - Returns: - True if the port was reserved (or already reserved), False if the port value is invalid. - - """ - if port <= 0 or port > 65535: - return False - - with self._lock: - if port in self._allocated_ports: - return True - - # Remove from recently released queue if present (we're explicitly reserving it) - if port in self._recently_released: - # Create a new deque without this port - self._recently_released = deque( - (p for p in self._recently_released if p != port), maxlen=_RECENTLY_RELEASED_PORTS_MAXLEN - ) - - self._allocated_ports.add(port) +def _is_port_available(port: int) -> bool: + """Check if a port is available for binding.""" + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", port)) return True + except OSError: + return False - @contextmanager - def allocated_port(self, preferred_port: Optional[int] = None) -> Iterator[int]: - """Context manager for automatic port cleanup. - - Usage: - with manager.allocated_port() as port: - # Use port here - pass - # Port automatically released - Args: - preferred_port: Optional preferred port number +def allocate_port_with_lock(preferred_port: Optional[int] = None, max_attempts: int = 100) -> int: + """Allocate a port with retry logic for parallel process coordination. - Yields: - Allocated port number + Uses simple OS port allocation with retry attempts. This approach accepts that + there's an inherent race condition between allocating a port and actually binding to it, + and handles it through retries rather than attempting to prevent it. - """ - port = self.allocate_port(preferred_port=preferred_port) - try: - yield port - finally: - self.release_port(port) + The race condition occurs because: + 1. We ask OS for a port → get port 50435 + 2. We close socket to return the port number + 3. Another process can grab port 50435 here ← RACE WINDOW + 4. TCPStore tries to bind → EADDRINUSE - @staticmethod - def _find_free_port() -> int: - """Find a free port using OS allocation. + This is unfixable without keeping the socket open, which isn't possible + when we only return a port number. File locks don't help because they can't + prevent the OS from reusing a port. 
- Returns: - A port number that was free at the time of checking - - """ - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # Don't use SO_REUSEADDR - we need to match the behavior of TCPStore - # which binds without it, so ports in TIME_WAIT will be rejected - s.bind(("", 0)) - port = s.getsockname()[1] - s.close() - return port + Args: + preferred_port: Try to use this port first if available + max_attempts: Maximum number of allocation attempts - @staticmethod - def _is_port_free(port: int) -> bool: - """Check if a specific port is available. - - Args: - port: Port number to check - - Returns: - True if the port is free, False otherwise + Returns: + An available port number - """ - try: - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # Don't use SO_REUSEADDR - we need to match the behavior of TCPStore - # which binds without it, so ports in TIME_WAIT will be rejected - s.bind(("", port)) - s.close() - return True - except OSError: - return False + Raises: + RuntimeError: If unable to allocate a port after max_attempts + """ + # Try preferred port first + if preferred_port and _is_port_available(preferred_port): + return preferred_port -# Global singleton instance -_port_manager: Optional[PortManager] = None -_port_manager_lock = threading.Lock() + # Simple OS allocation - let the kernel choose + # Multiple attempts help reduce collision probability when many parallel processes + # are allocating ports simultaneously + for attempt in range(max_attempts): + port = _find_free_port() + # Small random delay to reduce collision probability with parallel processes + # Only sleep on retry attempts, not the first try + if attempt > 0: + import random + import time -def get_port_manager() -> PortManager: - """Get or create the global port manager instance. + time.sleep(random.uniform(0.001, 0.01)) # noqa: S311 - Returns: - The global PortManager singleton + # Verify port is still available (best effort) + if _is_port_available(port): + return port - """ - global _port_manager - if _port_manager is None: - with _port_manager_lock: - if _port_manager is None: - _port_manager = PortManager() - return _port_manager + raise RuntimeError(f"Failed to allocate a free port after {max_attempts} attempts") diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index a06bb0eacdbb4..9d4a0b9462f2e 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import contextlib import os import sys import threading @@ -78,25 +77,9 @@ def restore_env_variables(): @pytest.fixture(autouse=True) def teardown_process_group(): """Ensures that the distributed process group gets closed before the next test runs.""" - import os - - from lightning.fabric.utilities.port_manager import get_port_manager - yield - - # Clean up distributed connection _destroy_dist_connection() - manager = get_port_manager() - - # If a process group created or updated MASTER_PORT during the test, reserve it and then clear it - if "MASTER_PORT" in os.environ: - with contextlib.suppress(ValueError): - port = int(os.environ["MASTER_PORT"]) - manager.reserve_existing_port(port) - manager.release_port(port) - os.environ.pop("MASTER_PORT", None) - @pytest.fixture(autouse=True) def thread_police_duuu_daaa_duuu_daaa(): @@ -221,67 +204,6 @@ def leave_no_artifacts_behind(): assert not difference, f"Test left artifacts behind: {difference}" -def pytest_runtest_makereport(item: pytest.Item, call: pytest.CallInfo) -> None: - """Retry tests that fail with EADDRINUSE errors. - - This handles the race condition where a port might enter TIME_WAIT between our port allocation check and actual - binding by TCPStore. - - """ - if call.excinfo is not None and call.when == "call": - exception_msg = str(call.excinfo.value) - exception_type = str(type(call.excinfo.value).__name__) - # Check if this is an EADDRINUSE error from distributed training - # Catch both direct EADDRINUSE errors and DistNetworkError which wraps them - if ( - "EADDRINUSE" in exception_msg - or "address already in use" in exception_msg.lower() - or "DistNetworkError" in exception_type - ): - # Get the retry count from the test node - retry_count = getattr(item, "_port_retry_count", 0) - max_retries = 3 - - if retry_count < max_retries: - # Increment retry counter - item._port_retry_count = retry_count + 1 - - # Log the retry - if hasattr(item.config, "get_terminal_writer"): - writer = item.config.get_terminal_writer() - writer.write( - f"\n[Port conflict detected] Retrying test {item.name} " - f"(attempt {retry_count + 2}/{max_retries + 1})...\n", - yellow=True, - ) - - # Clear the port manager's state to get fresh ports - from lightning.fabric.utilities.port_manager import get_port_manager - - manager = get_port_manager() - manager.release_all() - - # Clear MASTER_PORT so cluster environment allocates a fresh port on retry - import os - - os.environ.pop("MASTER_PORT", None) - - # Re-run the test by raising Rerun exception - # Note: This requires pytest-rerunfailures plugin - import time - - time.sleep(1.0) # Wait for OS to release ports from TIME_WAIT state - - # If pytest-rerunfailures is available, use it - try: - from pytest_rerunfailures import Rerun - - raise Rerun(f"Port conflict (EADDRINUSE), retry {retry_count + 1}/{max_retries}") - except ImportError: - # Plugin not available, just let the test fail - pass - - def pytest_collection_modifyitems(items: list[pytest.Function], config: pytest.Config) -> None: """An adaptation of `tests/tests_pytorch/conftest.py::pytest_collection_modifyitems`""" initial_size = len(items) diff --git a/tests/tests_fabric/plugins/environments/test_lightning.py b/tests/tests_fabric/plugins/environments/test_lightning.py index aa225475289f8..cc0179af4c3f5 100644 --- a/tests/tests_fabric/plugins/environments/test_lightning.py +++ b/tests/tests_fabric/plugins/environments/test_lightning.py @@ -17,7 +17,6 @@ import pytest from lightning.fabric.plugins.environments import LightningEnvironment -from 
lightning.fabric.utilities.port_manager import get_port_manager @mock.patch.dict(os.environ, {}, clear=True) @@ -85,18 +84,5 @@ def test_teardown(): assert "WORLD_SIZE" not in os.environ -@mock.patch.dict(os.environ, {}, clear=True) -def test_teardown_releases_port_and_env(): - env = LightningEnvironment() - port = env.main_port - assert port in get_port_manager()._allocated_ports - - env.teardown() - - assert port not in get_port_manager()._allocated_ports - assert "MASTER_PORT" not in os.environ - assert "MASTER_ADDR" not in os.environ - - def test_detect(): assert LightningEnvironment.detect() diff --git a/tests/tests_fabric/utilities/test_port_manager.py b/tests/tests_fabric/utilities/test_port_manager.py deleted file mode 100644 index 8ea820baa9a42..0000000000000 --- a/tests/tests_fabric/utilities/test_port_manager.py +++ /dev/null @@ -1,776 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for the PortManager utility and port allocation integration.""" - -import os -import socket -import threading -from collections import Counter - -import pytest - -from lightning.fabric.plugins.environments.lightning import find_free_network_port -from lightning.fabric.utilities.port_manager import PortManager, get_port_manager - -# ============================================================================= -# Fixtures -# ============================================================================= - - -@pytest.fixture -def with_master_port(): - """Fixture that sets MASTER_PORT before test runs, for conftest coverage.""" - port = find_free_network_port() - previous_value = os.environ.get("MASTER_PORT") - os.environ["MASTER_PORT"] = str(port) - try: - yield port - finally: - if previous_value is None: - os.environ.pop("MASTER_PORT", None) - else: - os.environ["MASTER_PORT"] = previous_value - - -@pytest.fixture -def with_invalid_master_port(): - """Fixture that sets invalid MASTER_PORT to test error handling.""" - previous_value = os.environ.get("MASTER_PORT") - os.environ["MASTER_PORT"] = "not_a_valid_port_number" - try: - yield - finally: - if previous_value is None: - os.environ.pop("MASTER_PORT", None) - else: - os.environ["MASTER_PORT"] = previous_value - - -# ============================================================================= -# Unit Tests for PortManager -# ============================================================================= - - -def test_port_manager_allocates_unique_ports(): - """Test that PortManager allocates unique ports.""" - manager = PortManager() - - # Allocate multiple ports - ports = [manager.allocate_port() for _ in range(10)] - - # All ports should be unique - assert len(ports) == len(set(ports)), f"Duplicate ports found: {ports}" - - # All ports should be valid (>= 1024) - assert all(p >= 1024 for p in ports), "Some ports are in reserved range" - - -def test_port_manager_release_port(): - """Test that released ports are removed from the allocated set.""" - manager = PortManager() - - # Allocate a port - port = 
manager.allocate_port() - assert port in manager._allocated_ports - - # Release the port - manager.release_port(port) - assert port not in manager._allocated_ports - - -def test_port_manager_release_all(): - """Test that release_all clears all allocated ports.""" - manager = PortManager() - - # Allocate multiple ports - [manager.allocate_port() for _ in range(5)] - assert len(manager._allocated_ports) == 5 - - # Release all - manager.release_all() - assert len(manager._allocated_ports) == 0 - - -def test_port_manager_release_nonexistent_port(): - """Test that releasing a non-existent port doesn't cause errors.""" - manager = PortManager() - - # Try to release a port that was never allocated - manager.release_port(12345) # Should not raise an error - - # Verify nothing broke - port = manager.allocate_port() - assert port >= 1024 - - -def test_port_manager_thread_safety(): - """Test that PortManager is thread-safe under concurrent access.""" - manager = PortManager() - ports = [] - lock = threading.Lock() - - def allocate_ports(): - """Allocate multiple ports from different threads.""" - for _ in range(10): - port = manager.allocate_port() - with lock: - ports.append(port) - - # Create multiple threads that allocate ports concurrently - threads = [threading.Thread(target=allocate_ports) for _ in range(10)] - - # Start all threads - for thread in threads: - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # Verify we got 100 unique ports (10 threads × 10 ports each) - assert len(ports) == 100, f"Expected 100 ports, got {len(ports)}" - assert len(set(ports)) == 100, f"Expected 100 unique ports, got {len(set(ports))}" - - # Check for any duplicates - counts = Counter(ports) - duplicates = {port: count for port, count in counts.items() if count > 1} - assert not duplicates, f"Found duplicate ports: {duplicates}" - - -def test_port_manager_preferred_port(): - """Test that PortManager can allocate a preferred port if available.""" - manager = PortManager() - - # Try to find a free port first - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("", 0)) - free_port = s.getsockname()[1] - s.close() - - # Allocate the preferred port - allocated = manager.allocate_port(preferred_port=free_port) - assert allocated == free_port - - # Try to allocate the same preferred port again (should get a different one) - allocated2 = manager.allocate_port(preferred_port=free_port) - assert allocated2 != free_port - - -def test_port_manager_allocation_failure(): - """Test that PortManager raises error when unable to allocate after max attempts.""" - manager = PortManager() - - # Pre-allocate a large number of ports to make it harder to find a free one - # Then try with max_attempts=1 which should fail quickly - allocated_ports = [manager.allocate_port() for _ in range(50)] - - # Test that it can still allocate with enough attempts - port = manager.allocate_port(max_attempts=100) - assert port >= 1024 - - # Clean up - for p in allocated_ports: - manager.release_port(p) - manager.release_port(port) - - -def test_port_manager_prevents_reallocation(): - """Test that a port won't be allocated twice until released.""" - manager = PortManager() - - # Allocate a port - port1 = manager.allocate_port() - - # Allocate many more ports - none should match port1 - more_ports = [manager.allocate_port() for _ in range(50)] - - # port1 should not appear in more_ports - assert port1 not in more_ports, f"Port {port1} was reallocated before release" - - # After releasing 
port1, we should eventually be able to get it again - # (though not guaranteed due to OS port allocation) - manager.release_port(port1) - assert port1 not in manager._allocated_ports - - # Clean up - for port in more_ports: - manager.release_port(port) - - -def test_get_port_manager_singleton(): - """Test that get_port_manager returns the same instance.""" - manager1 = get_port_manager() - manager2 = get_port_manager() - - # Should be the same instance - assert manager1 is manager2 - - # Allocating from one should be visible in the other - port = manager1.allocate_port() - assert port in manager2._allocated_ports - - # Clean up - manager1.release_port(port) - - -def test_get_port_manager_thread_safe_singleton(): - """Test that get_port_manager creates singleton safely across threads.""" - managers = [] - lock = threading.Lock() - - def get_manager(): - manager = get_port_manager() - with lock: - managers.append(manager) - - # Create multiple threads that get the port manager - threads = [threading.Thread(target=get_manager) for _ in range(20)] - - for thread in threads: - thread.start() - for thread in threads: - thread.join() - - # All should be the same instance - assert len(managers) == 20 - assert all(m is managers[0] for m in managers), "get_port_manager returned different instances" - - -def test_port_manager_is_port_free(): - """Test the _is_port_free helper method.""" - manager = PortManager() - - # Find a free port using OS - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("", 0)) - free_port = s.getsockname()[1] - - # Port should be reported as not free while socket is open - assert not manager._is_port_free(free_port) - - # Close the socket - s.close() - - # Now port should be free (though there's still a small race condition) - # We'll skip this check as it's unreliable due to OS behavior - - -def test_port_manager_find_free_port(): - """Test the _find_free_port helper method.""" - manager = PortManager() - - # Should return a valid port - port = manager._find_free_port() - assert isinstance(port, int) - assert port >= 1024 - assert port <= 65535 - - -def test_port_manager_concurrent_allocation_and_release(): - """Test concurrent allocation and release operations.""" - manager = PortManager() - ports = [] - lock = threading.Lock() - active_ports: set[int] = set() - - def allocate_and_release(): - for _ in range(5): - # Allocate a port - port = manager.allocate_port() - with lock: - assert port not in active_ports, "Port allocated concurrently" - active_ports.add(port) - ports.append(port) - - # Release it immediately - manager.release_port(port) - with lock: - active_ports.remove(port) - - # Run multiple threads - threads = [threading.Thread(target=allocate_and_release) for _ in range(10)] - - for thread in threads: - thread.start() - for thread in threads: - thread.join() - - # Should have allocated 50 ports total (10 threads × 5 ports) - assert len(ports) == 50 - - # After all releases, manager should have no ports allocated - assert len(manager._allocated_ports) == 0 - - -def test_port_manager_atexit_cleanup(): - """Test that PortManager registers atexit cleanup.""" - - # Create a new manager - manager = PortManager() - - # The manager should have registered release_all with atexit - # We can't easily test atexit directly, but we can verify the method exists - assert callable(manager.release_all) - - # Verify release_all works - manager.allocate_port() - manager.allocate_port() - assert len(manager._allocated_ports) == 2 - - manager.release_all() - assert 
len(manager._allocated_ports) == 0 - - -def test_port_manager_reserve_existing_port_free(): - """reserve_existing_port should succeed for free ports and track them.""" - manager = PortManager() - - port = manager._find_free_port() - assert manager.reserve_existing_port(port) - assert port in manager._allocated_ports - - # Second call should succeed but not duplicate - assert manager.reserve_existing_port(port) - assert len(manager._allocated_ports) == 1 - - -def test_port_manager_reserve_existing_port_invalid_value(): - """reserve_existing_port should reject invalid port numbers.""" - manager = PortManager() - - assert not manager.reserve_existing_port(0) - assert not manager.reserve_existing_port(-1) - assert not manager.reserve_existing_port(70000) - - -def test_port_manager_reserve_existing_port_after_release(): - """Ports released from sockets should become reservable.""" - manager = PortManager() - - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind(("", 0)) - reusable_port = s.getsockname()[1] - s.close() - - assert manager.reserve_existing_port(reusable_port) - assert reusable_port in manager._allocated_ports - - -def test_port_manager_context_manager(): - """Test that context manager automatically releases ports.""" - manager = PortManager() - - # Use context manager - with manager.allocated_port() as port: - # Port should be allocated - assert port in manager._allocated_ports - assert isinstance(port, int) - assert port >= 1024 - - # After context, port should be released - assert port not in manager._allocated_ports - - -def test_port_manager_context_manager_exception(): - """Test that context manager releases port even on exception.""" - manager = PortManager() - - try: - with manager.allocated_port() as port: - allocated_port = port - # Port should be allocated - assert port in manager._allocated_ports - # Raise exception - raise ValueError("Test exception") - except ValueError: - pass - - # Port should still be released despite exception - assert allocated_port not in manager._allocated_ports - - -def test_port_manager_context_manager_nested(): - """Test that nested context managers work correctly.""" - manager = PortManager() - - with manager.allocated_port() as port1: - assert port1 in manager._allocated_ports - - with manager.allocated_port() as port2: - # Both ports should be allocated - assert port1 in manager._allocated_ports - assert port2 in manager._allocated_ports - # Ports should be different - assert port1 != port2 - - # port2 should be released, port1 still allocated - assert port1 in manager._allocated_ports - assert port2 not in manager._allocated_ports - - # Both ports should now be released - assert port1 not in manager._allocated_ports - assert port2 not in manager._allocated_ports - - -# ============================================================================= -# Integration Tests for find_free_network_port() -# ============================================================================= - - -def test_find_free_network_port_uses_port_manager(): - """Test that find_free_network_port uses the PortManager.""" - manager = get_port_manager() - - # Clear any previously allocated ports - initial_count = len(manager._allocated_ports) - - # Allocate a port using the function - port = find_free_network_port() - - # The port should be in the manager's allocated set - assert port in manager._allocated_ports - assert len(manager._allocated_ports) == initial_count + 1 - - # Clean up - 
manager.release_port(port) - - -def test_find_free_network_port_returns_unique_ports(): - """Test that multiple calls return unique ports.""" - manager = get_port_manager() - - # Allocate multiple ports - ports = [find_free_network_port() for _ in range(10)] - - # All should be unique - assert len(ports) == len(set(ports)), f"Duplicate ports: {ports}" - - # All should be tracked by the manager - for port in ports: - assert port in manager._allocated_ports - - # Clean up - for port in ports: - manager.release_port(port) - - -def test_find_free_network_port_thread_safety(): - """Test that find_free_network_port is thread-safe.""" - ports = [] - lock = threading.Lock() - - def allocate(): - for _ in range(5): - port = find_free_network_port() - with lock: - ports.append(port) - - # Run 10 threads, each allocating 5 ports - threads = [threading.Thread(target=allocate) for _ in range(10)] - - for thread in threads: - thread.start() - for thread in threads: - thread.join() - - # Should have 50 unique ports - assert len(ports) == 50 - assert len(set(ports)) == 50, "Duplicate ports allocated across threads" - - # Clean up - manager = get_port_manager() - for port in ports: - manager.release_port(port) - - -def test_port_allocation_simulates_distributed_test_lifecycle(): - """Simulate the lifecycle of a distributed test with port allocation and release.""" - manager = get_port_manager() - initial_count = len(manager._allocated_ports) - - # Simulate test setup: allocate a port - port = find_free_network_port() - os.environ["MASTER_PORT"] = str(port) - - # Verify port is allocated - assert port in manager._allocated_ports - - # Simulate test teardown: release the port - if "MASTER_PORT" in os.environ: - port_to_release = int(os.environ["MASTER_PORT"]) - manager.release_port(port_to_release) - del os.environ["MASTER_PORT"] - - # Verify port is released - assert port not in manager._allocated_ports - assert len(manager._allocated_ports) == initial_count - - -def test_conftest_cleanup_with_master_port_set(with_master_port): - """Test conftest cleanup when MASTER_PORT is set before test starts. - - This test uses a fixture to set MASTER_PORT before the test runs, allowing the conftest teardown_process_group - fixture to capture and clean it up. This ensures the conftest cleanup code is covered. - - """ - manager = get_port_manager() - port = with_master_port # Port was set by fixture - - # Verify port is allocated - assert port in manager._allocated_ports - assert os.environ.get("MASTER_PORT") == str(port) - - # Leave MASTER_PORT set - conftest teardown will clean it up - # After this test, teardown_process_group will: - # 1. Detect MASTER_PORT in os.environ (line captured before yield) - # 2. Call get_port_manager().release_port(port) - # 3. Port gets released back to manager - - -def test_conftest_handles_invalid_master_port(with_invalid_master_port): - """Test conftest handles invalid MASTER_PORT gracefully. - - This exercises the contextlib.suppress(ValueError, KeyError) path in the conftest teardown_process_group fixture. 
- - """ - # Fixture set MASTER_PORT to "not_a_valid_port_number" - # The conftest will try to parse it: int(os.environ["MASTER_PORT"]) - # This will raise ValueError, which should be caught by contextlib.suppress - - # Verify the invalid value is set - assert os.environ.get("MASTER_PORT") == "not_a_valid_port_number" - - # This test just needs to complete without crashing - # The conftest teardown will handle the ValueError gracefully - - -def test_multiple_tests_can_reuse_ports_after_release(): - """Test that ports can be reused after being released.""" - manager = get_port_manager() - - # First "test" allocates a port - port1 = find_free_network_port() - assert port1 in manager._allocated_ports - - # First "test" completes and releases the port - manager.release_port(port1) - assert port1 not in manager._allocated_ports - - # Second "test" allocates ports (may or may not get the same port) - port2 = find_free_network_port() - assert port2 in manager._allocated_ports - - # Ports should be valid regardless - assert port1 >= 1024 - assert port2 >= 1024 - - # Clean up - manager.release_port(port2) - - -def test_concurrent_tests_dont_get_same_port(): - """Test that concurrent tests never receive the same port.""" - manager = get_port_manager() - ports_per_thread = [] - lock = threading.Lock() - - def simulate_test(): - """Simulate a test that allocates a port, uses it, then releases it.""" - my_ports = [] - - # Allocate port for this "test" - port = find_free_network_port() - my_ports.append(port) - - # Simulate some work - import time - - time.sleep(0.001) - - # Release port after "test" completes - manager.release_port(port) - - with lock: - ports_per_thread.append(my_ports) - - # Run 20 concurrent "tests" - threads = [threading.Thread(target=simulate_test) for _ in range(20)] - - for thread in threads: - thread.start() - for thread in threads: - thread.join() - - # Collect all ports that were allocated - all_ports = [port for thread_ports in ports_per_thread for port in thread_ports] - - # All ports should have been unique at the time of allocation - assert len(all_ports) == 20 - assert len(set(all_ports)) == 20, "Some concurrent tests got the same port!" 
- - -def test_port_manager_survives_multiple_test_sessions(): - """Test that the port manager maintains state across multiple test sessions.""" - manager = get_port_manager() - - # Session 1: Allocate some ports - session1_ports = [find_free_network_port() for _ in range(3)] - - # Session 2: Allocate more ports (should not overlap with session 1) - session2_ports = [find_free_network_port() for _ in range(3)] - - # No overlap between sessions while both are active - assert not set(session1_ports) & set(session2_ports) - - # Release session 1 ports - for port in session1_ports: - manager.release_port(port) - - # Session 3: Can allocate more ports - session3_ports = [find_free_network_port() for _ in range(3)] - - # Session 3 shouldn't overlap with active session 2 - assert not set(session2_ports) & set(session3_ports) - - # Clean up - for port in session2_ports + session3_ports: - manager.release_port(port) - - -def test_port_manager_allocation_runtime_error(): - """Test that allocation fails gracefully when max_attempts is exhausted.""" - manager = PortManager() - - # Allocate a port first - allocated_port = manager.allocate_port() - - # Mock _find_free_port to always return the already-allocated port - # This will cause all allocation attempts to be skipped (port in allocated_ports) - original_find = manager._find_free_port - - def always_return_allocated(): - return allocated_port - - manager._find_free_port = always_return_allocated - - # This should raise RuntimeError after max_attempts - with pytest.raises(RuntimeError, match="Failed to allocate a free port after .* attempts"): - manager.allocate_port(max_attempts=5) - - # Restore original method and clean up - manager._find_free_port = original_find - manager.release_port(allocated_port) - - -def test_find_free_network_port_respects_existing_master_port(with_master_port): - """find_free_network_port should reuse externally provided MASTER_PORT.""" - manager = get_port_manager() - port = with_master_port - - returned_port = find_free_network_port() - assert returned_port == port - assert port in manager._allocated_ports - - -def test_find_free_network_port_handles_invalid_master_port(with_invalid_master_port): - """Invalid MASTER_PORT values should fall back to allocating a fresh port.""" - manager = get_port_manager() - - returned_port = find_free_network_port() - assert isinstance(returned_port, int) - assert returned_port in manager._allocated_ports - assert returned_port != "not_a_valid_port_number" - - -def test_port_manager_recently_released_prevents_immediate_reuse(): - """Test that released ports enter recently_released queue and can't be immediately reallocated.""" - manager = PortManager() - - # Allocate and release a port - port = manager.allocate_port() - manager.release_port(port) - - # Port should be in recently_released queue - assert port in manager._recently_released - assert port not in manager._allocated_ports - - # Try to allocate again - should get a different port - new_port = manager.allocate_port() - assert new_port != port - assert new_port in manager._allocated_ports - - manager.release_port(new_port) - - -def test_port_manager_recently_released_queue_cycles(): - """Test that recently_released queue cycles after maxlen allocations.""" - from lightning.fabric.utilities.port_manager import _RECENTLY_RELEASED_PORTS_MAXLEN - - manager = PortManager() - - # Allocate and release a port - first_port = manager.allocate_port() - manager.release_port(first_port) - - # Port should be in recently_released queue - assert 
first_port in manager._recently_released - - # Allocate and release many ports to fill the queue beyond maxlen - for _ in range(_RECENTLY_RELEASED_PORTS_MAXLEN + 10): - port = manager.allocate_port() - manager.release_port(port) - - # First port should have been evicted from the queue (oldest entry) - assert first_port not in manager._recently_released - - -def test_port_manager_reserve_clears_recently_released(): - """Test that reserve_existing_port clears recently_released for that port.""" - manager = PortManager() - - # Allocate and release a port - port = manager.allocate_port() - manager.release_port(port) - - # Port should be in recently_released - assert port in manager._recently_released - - # Reserve the port - should clear from recently_released and mark as allocated - assert manager.reserve_existing_port(port) - assert port not in manager._recently_released - assert port in manager._allocated_ports - - manager.release_port(port) - - -def test_port_manager_high_queue_utilization_warning(caplog): - """Test that warning is logged when queue utilization exceeds 80%.""" - import logging - - manager = PortManager() - - # Fill queue to >80% (821/1024 = 80.2%) - for _ in range(821): - port = manager.allocate_port() - manager.release_port(port) - - # Next allocation should trigger warning - with caplog.at_level(logging.WARNING): - port = manager.allocate_port() - manager.release_port(port) - - # Verify warning was logged - assert any("Port queue utilization high" in record.message for record in caplog.records) - assert any("80." in record.message for record in caplog.records) # Should show 80.x% diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index bb48ca8717e45..878298c6bfd94 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import os import signal import sys @@ -128,25 +127,9 @@ def restore_signal_handlers(): @pytest.fixture(autouse=True) def teardown_process_group(): """Ensures that the distributed process group gets closed before the next test runs.""" - import os - - from lightning.fabric.utilities.port_manager import get_port_manager - yield - - # Clean up distributed connection _destroy_dist_connection() - manager = get_port_manager() - - # If a process group created or updated MASTER_PORT during the test, reserve it and then clear it - if "MASTER_PORT" in os.environ: - with contextlib.suppress(ValueError): - port = int(os.environ["MASTER_PORT"]) - manager.reserve_existing_port(port) - manager.release_port(port) - os.environ.pop("MASTER_PORT", None) - @pytest.fixture(autouse=True) def reset_deterministic_algorithm(): @@ -349,67 +332,6 @@ def leave_no_artifacts_behind(): assert not difference, f"Test left artifacts behind: {difference}" -def pytest_runtest_makereport(item: pytest.Item, call: pytest.CallInfo) -> None: - """Retry tests that fail with EADDRINUSE errors. - - This handles the race condition where a port might enter TIME_WAIT between our port allocation check and actual - binding by TCPStore. 
- - """ - if call.excinfo is not None and call.when == "call": - exception_msg = str(call.excinfo.value) - exception_type = str(type(call.excinfo.value).__name__) - # Check if this is an EADDRINUSE error from distributed training - # Catch both direct EADDRINUSE errors and DistNetworkError which wraps them - if ( - "EADDRINUSE" in exception_msg - or "address already in use" in exception_msg.lower() - or "DistNetworkError" in exception_type - ): - # Get the retry count from the test node - retry_count = getattr(item, "_port_retry_count", 0) - max_retries = 3 - - if retry_count < max_retries: - # Increment retry counter - item._port_retry_count = retry_count + 1 - - # Log the retry - if hasattr(item.config, "get_terminal_writer"): - writer = item.config.get_terminal_writer() - writer.write( - f"\n[Port conflict detected] Retrying test {item.name} " - f"(attempt {retry_count + 2}/{max_retries + 1})...\n", - yellow=True, - ) - - # Clear the port manager's state to get fresh ports - from lightning.fabric.utilities.port_manager import get_port_manager - - manager = get_port_manager() - manager.release_all() - - # Clear MASTER_PORT so cluster environment allocates a fresh port on retry - import os - - os.environ.pop("MASTER_PORT", None) - - # Re-run the test by raising Rerun exception - # Note: This requires pytest-rerunfailures plugin - import time - - time.sleep(1.0) # Wait for OS to release ports from TIME_WAIT state - - # If pytest-rerunfailures is available, use it - try: - from pytest_rerunfailures import Rerun - - raise Rerun(f"Port conflict (EADDRINUSE), retry {retry_count + 1}/{max_retries}") - except ImportError: - # Plugin not available, just let the test fail - pass - - def pytest_collection_modifyitems(items: list[pytest.Function], config: pytest.Config) -> None: initial_size = len(items) conditions = []