|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import os |
16 | | -import socket |
17 | 16 |
|
18 | 17 | from typing_extensions import override |
19 | 18 |
|
20 | 19 | from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment |
| 20 | +from lightning.fabric.utilities.port_manager import get_port_manager |
21 | 21 | from lightning.fabric.utilities.rank_zero import rank_zero_only |
22 | 22 |
|
23 | 23 |
|
@@ -104,16 +104,38 @@ def teardown(self) -> None: |
104 | 104 | if "WORLD_SIZE" in os.environ: |
105 | 105 | del os.environ["WORLD_SIZE"] |
106 | 106 |
|
| 107 | + if self._main_port != -1: |
| 108 | + get_port_manager().release_port(self._main_port) |
| 109 | + self._main_port = -1 |
| 110 | + |
| 111 | + os.environ.pop("MASTER_PORT", None) |
| 112 | + os.environ.pop("MASTER_ADDR", None) |
| 113 | + |
107 | 114 |
|
108 | 115 | def find_free_network_port() -> int: |
109 | 116 | """Finds a free port on localhost. |
110 | 117 |
|
111 | 118 | It is useful in single-node training when we don't want to connect to a real main node but have to set the |
112 | 119 | `MASTER_PORT` environment variable. |
113 | 120 |
|
| 121 | + The allocated port is reserved and won't be returned by subsequent calls until it's explicitly released. |
| 122 | +
|
| 123 | + Returns: |
| 124 | + A port number that was free at the time of allocation and stays reserved until it is released. |
| 125 | +
|
114 | 126 | """ |
115 | | - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
116 | | - s.bind(("", 0)) |
117 | | - port = s.getsockname()[1] |
118 | | - s.close() |
119 | | - return port |
| 127 | + # If an external launcher already specified a MASTER_PORT (for example, torch.multiprocessing.spawn or |
| 128 | + # multiprocessing helpers), reserve it through the port manager so no other test reuses the same number. |
| 129 | + if "MASTER_PORT" in os.environ: |
| 130 | + master_port_str = os.environ["MASTER_PORT"] |
| 131 | + try: |
| 132 | + existing_port = int(master_port_str) |
| 133 | + except ValueError: |
| 134 | + pass |
| 135 | + else: |
| 136 | + port_manager = get_port_manager() |
| 137 | + if port_manager.reserve_existing_port(existing_port): |
| 138 | + return existing_port |
| 139 | + |
| 140 | + port_manager = get_port_manager() |
| 141 | + return port_manager.allocate_port() |
0 commit comments