Commit 2b7e4c3

kevinmtang authored and pytorchmergebot committed
[DCP] Add option to use PrefixStore to create checkpoint background process (pytorch#166560)
Summary: The DCP checkpoint background process currently determines the port used for its process group via get_free_port(). During background process initialization, gloo pg init occasionally times out on the first call but succeeds on a subsequent call. We hypothesized that the timeouts are related to the port being used, and that the fix is to create the pg with a PrefixStore that reuses the master port.

This diff adds the option for the checkpoint background process to use a PrefixStore built on MASTER_ADDR + MASTER_PORT. The default behavior is unchanged; enabling the new PrefixStore behavior requires setting the "DCP_USE_PREFIX_STORE" env var to "1".

context: https://fb.workplace.com/groups/319878845696681/permalink/1516883985996155/

Differential Revision: D84928180

Pull Request resolved: pytorch#166560
Approved by: https://github.com/meetv18
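For reference, a minimal sketch of opting in from a training script, assuming a single-rank gloo setup. The env var name and its "1"-only parsing come from this PR; everything else is standard torch.distributed usage:

import os

import torch.distributed as dist

# Opt in before the first async save; leaving the var unset (or set to
# anything other than "1") keeps the old get_free_port() behavior.
os.environ["DCP_USE_PREFIX_STORE"] = "1"

# With the flag on, both MASTER_ADDR and MASTER_PORT must be set; otherwise
# _ProcessGroupInitInfo fails (surfaced as a CheckpointException in the tests).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

dist.init_process_group(backend="gloo", rank=0, world_size=1)
# Subsequent async DCP saves spawn the checkpoint background process, which
# now builds its gloo pg from a PrefixStore over MASTER_ADDR:MASTER_PORT
# instead of a freshly allocated free port.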
1 parent 6c98657 commit 2b7e4c3

File tree

2 files changed: +229 -47 lines changed

test/distributed/checkpoint/test_async_process_executor.py

Lines changed: 184 additions & 37 deletions
@@ -1,16 +1,26 @@
 # Owner(s): ["oncall: distributed checkpointing"]
 
+import os
 import sys
 from unittest.mock import patch
 
 import torch
+import torch.testing._internal.common_utils as common
 from torch import distributed as dist
 from torch.distributed.checkpoint._async_process_executor import (
     _ProcessBasedAsyncCheckpointExecutor,
+    _ProcessGroupInitInfo,
 )
+from torch.distributed.checkpoint.api import CheckpointException
 from torch.distributed.checkpoint.storage import StorageWriter
 from torch.distributed.elastic.utils.distributed import get_free_port
-from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
+from torch.testing._internal.common_distributed import skip_if_win32
+from torch.testing._internal.common_utils import (
+    retry_on_connect_failures,
+    run_tests,
+    TEST_WITH_DEV_DBG_ASAN,
+    TestCase,
+)
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     with_comms,
@@ -110,47 +120,184 @@ def test_checkpoint_save_failure_continues_serving(self) -> None:
             "epoch": 5,
         }
 
-        # 1. Simulate a failure in creating PG in background process.
-        with patch(
-            "torch.distributed.checkpoint._async_process_executor.get_free_port",
-            return_value=-1,
+        with patch.dict(os.environ, {}, clear=False):
+            os.environ.pop("DCP_USE_PREFIX_STORE", None)
+
+            # 1. Simulate a failure in creating PG in background process.
+            with patch(
+                "torch.distributed.checkpoint._async_process_executor.get_free_port",
+                return_value=-1,
+            ):
+                with self.assertRaises(ValueError) as _:
+                    proc_executor = _ProcessBasedAsyncCheckpointExecutor()
+                    fut = proc_executor.execute_save(
+                        staging_future_or_state_dict=test_state_dict,
+                    )
+                    fut.result()
+
+            # 2. Attempt save with failing storage writer
+            with patch(
+                "torch.distributed.checkpoint._async_process_executor.get_free_port",
+                return_value=get_free_port(),
+            ) as mock_get_free_port:
+                proc_executor = _ProcessBasedAsyncCheckpointExecutor()
+                fut = proc_executor.execute_save(
+                    staging_future_or_state_dict=test_state_dict,
+                    storage_writer=TestStorageWriter(behavior="fail_once"),
+                )
+                self.assertIn(
+                    "fail_once policy triggered failure", str(fut.exception())
+                )
+                # Verify new process was created for this attempt
+                if dist.get_rank() == 0:
+                    mock_get_free_port.assert_called_once()
+
+            # 3. Second save attempt with successful storage writer - process should still be alive
+            with patch(
+                "torch.distributed.checkpoint._async_process_executor.get_free_port",
+            ) as mock_get_free_port:
+                proc_executor = _ProcessBasedAsyncCheckpointExecutor()
+                fut = proc_executor.execute_save(
+                    staging_future_or_state_dict=test_state_dict,
+                    storage_writer=TestStorageWriter(behavior="success"),
+                )
+                result = fut.result()
+                # Verify process is still alive
+                mock_get_free_port.assert_not_called()
+                # Verify successful save
+                self.assertIsNotNone(result)
+
+
+class TestAsyncProcessExecutorPrefixStore(TestCase):
+    @skip_if_win32()
+    @retry_on_connect_failures
+    def test_checkpoint_save_with_prefix_store_enabled(self) -> None:
+        """Test that checkpoint save works when DCP_USE_PREFIX_STORE is enabled."""
+
+        test_state_dict = {
+            "model": {"weight": torch.randn(4, 4), "bias": torch.randn(4)},
+            "optimizer": {"param_groups": [{"lr": 0.01}]},
+            "epoch": 5,
+        }
+
+        master_addr = "localhost"
+        master_port = str(common.find_free_port())
+
+        with patch.dict(
+            os.environ,
+            {
+                "DCP_USE_PREFIX_STORE": "1",
+                "MASTER_ADDR": master_addr,
+                "MASTER_PORT": master_port,
+            },
         ):
-            with self.assertRaises(ValueError) as _:
+            with patch(
+                "torch.distributed.checkpoint._async_process_executor.get_free_port"
+            ) as mock_get_free_port:
+                dist.init_process_group(
+                    backend=dist.Backend.GLOO,
+                    rank=0,
+                    world_size=1,
+                )
+
                 proc_executor = _ProcessBasedAsyncCheckpointExecutor()
                 fut = proc_executor.execute_save(
                     staging_future_or_state_dict=test_state_dict,
+                    storage_writer=TestStorageWriter(behavior="success"),
                 )
-                fut.result()
-
-        # 2. Attempt save with failing storage writer
-        with patch(
-            "torch.distributed.checkpoint._async_process_executor.get_free_port",
-            return_value=get_free_port(),
-        ) as mock_get_free_port:
-            proc_executor = _ProcessBasedAsyncCheckpointExecutor()
-            fut = proc_executor.execute_save(
-                staging_future_or_state_dict=test_state_dict,
-                storage_writer=TestStorageWriter(behavior="fail_once"),
-            )
-            self.assertIn("fail_once policy triggered failure", str(fut.exception()))
-            # Verify new process was created for this attempt
-            if dist.get_rank() == 0:
-                mock_get_free_port.assert_called_once()
-
-        # 3. Second save attempt with successful storage writer - process should still be alive
-        with patch(
-            "torch.distributed.checkpoint._async_process_executor.get_free_port",
-        ) as mock_get_free_port:
-            proc_executor = _ProcessBasedAsyncCheckpointExecutor()
-            fut = proc_executor.execute_save(
-                staging_future_or_state_dict=test_state_dict,
-                storage_writer=TestStorageWriter(behavior="success"),
-            )
-            result = fut.result()
-            # Verify process is still alive
-            mock_get_free_port.assert_not_called()
-            # Verify successful save
-            self.assertIsNotNone(result)
+                result = fut.result()
+                self.assertIsNotNone(result)
+                mock_get_free_port.assert_not_called()
+
+
+class TestProcessGroupInitInfo(DTensorTestBase):
+    """Test suite for _ProcessGroupInitInfo."""
+
+    @with_comms
+    def test_process_group_init_info_with_default_pg(self) -> None:
+        """Test that ProcessGroupInitInfo correctly initializes."""
+        with patch.dict(os.environ, {}, clear=False):
+            os.environ.pop("DCP_USE_PREFIX_STORE", None)
+
+            pg_init_info = _ProcessGroupInitInfo()
+
+            self.assertEqual(pg_init_info.global_rank, dist.get_rank())
+            self.assertEqual(pg_init_info.world_size, dist.get_world_size())
+            self.assertIsNotNone(pg_init_info.tcp_store_master_addr)
+            self.assertGreater(pg_init_info.tcp_store_master_port, 0)
+            self.assertEqual(pg_init_info.use_prefix_store, False)
+
+    @with_comms
+    def test_process_group_init_info_with_prefix_store_env_var(self) -> None:
+        """Test that ProcessGroupInitInfo handles DCP_USE_PREFIX_STORE environment variable."""
+
+        # Flag enabled, addr/port correctly defined
+        with patch.dict(
+            os.environ,
+            {
+                "DCP_USE_PREFIX_STORE": "1",
+                "MASTER_ADDR": "localhost",
+                "MASTER_PORT": "12345",
+            },
+        ):
+            pg_init_info = _ProcessGroupInitInfo()
+            self.assertTrue(pg_init_info.use_prefix_store)
+
+        # Missing port
+        with patch.dict(
+            os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_ADDR": "localhost"}
+        ):
+            with self.assertRaises(CheckpointException):
+                pg_init_info = _ProcessGroupInitInfo()
+        # Missing addr
+        with patch.dict(
+            os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_PORT": "12345"}
+        ):
+            with self.assertRaises(CheckpointException):
+                pg_init_info = _ProcessGroupInitInfo()
+        # Invalid port
+        with patch.dict(
+            os.environ,
+            {
+                "DCP_USE_PREFIX_STORE": "1",
+                "MASTER_ADDR": "localhost",
+                "MASTER_PORT": "a",
+            },
+        ):
+            with self.assertRaises(CheckpointException):
+                pg_init_info = _ProcessGroupInitInfo()
+
+    @with_comms
+    def test_process_group_init_info_without_prefix_store_env_var(self) -> None:
+        """Test that ProcessGroupInitInfo defaults to not using prefix store."""
+
+        # Env var set to 0
+        with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "0"}):
+            pg_init_info = _ProcessGroupInitInfo()
+            self.assertFalse(pg_init_info.use_prefix_store)
+
+        # Missing env var
+        with patch.dict(os.environ, {}, clear=False):
+            os.environ.pop("DCP_USE_PREFIX_STORE", None)
+            pg_init_info = _ProcessGroupInitInfo()
+            self.assertFalse(pg_init_info.use_prefix_store)
+
+        # Invalid env var
+        with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "2"}):
+            pg_init_info = _ProcessGroupInitInfo()
+            self.assertFalse(pg_init_info.use_prefix_store)
+
+        with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "true"}):
+            pg_init_info = _ProcessGroupInitInfo()
+            self.assertFalse(pg_init_info.use_prefix_store)
+
+        with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "false"}):
+            pg_init_info = _ProcessGroupInitInfo()
+            self.assertFalse(pg_init_info.use_prefix_store)
+
+        with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": ""}):
+            pg_init_info = _ProcessGroupInitInfo()
+            self.assertFalse(pg_init_info.use_prefix_store)
 
 
 if __name__ == "__main__":
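Note the strictness the last test pins down: per the executor change below, only the literal string "1" enables the feature. A tiny illustrative check (the helper name here is hypothetical):

import os

def _use_prefix_store() -> bool:
    # Mirrors the parsing in _ProcessGroupInitInfo: "true", "false", "2",
    # "" and unset all fall through to the default get_free_port() path.
    return os.environ.get("DCP_USE_PREFIX_STORE", "0") == "1"

os.environ["DCP_USE_PREFIX_STORE"] = "true"
assert not _use_prefix_store()  # not the literal "1"
os.environ["DCP_USE_PREFIX_STORE"] = "1"
assert _use_prefix_store()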

torch/distributed/checkpoint/_async_process_executor.py

Lines changed: 45 additions & 10 deletions
@@ -10,6 +10,7 @@
 
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from torch.distributed import PrefixStore, TCPStore
 from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor
 from torch.distributed.checkpoint.logger import _dcp_method_logger, _init_logger
 from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
@@ -55,15 +56,17 @@ class _ProcessGroupInitInfo:
     world_size: int
     tcp_store_master_addr: str
     tcp_store_master_port: int
+    use_prefix_store: bool
 
     def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
         self.local_rank = dist.get_node_local_rank(fallback_rank=0)
         self.global_rank = dist.get_rank(process_group)
         self.world_size = dist.get_world_size(process_group)
+        self.use_prefix_store = os.environ.get("DCP_USE_PREFIX_STORE", "0") == "1"
 
-        # Let coordinator rank find a free port on the localhost.
-        # Broadcast the (master_addr, free_port) to all ranks; each rank in the
-        # checkpoint daemon process will use TCPStore (master_addr, master_port)
+        # Let coordinator rank find a port on the localhost.
+        # Broadcast the (master_addr, port) to all ranks; each rank in the
+        # checkpoint daemon process will use TCPStore (master_addr, port)
         # for collective communication.
         dist_wrapper: _DistWrapper = _DistWrapper(
             group=process_group,
@@ -72,10 +75,23 @@ def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
         )
 
         def get_master_addr_and_port() -> tuple[str, int]:
-            master_addr = os.environ.get("MASTER_ADDR")
-            if master_addr is None:
-                master_addr = _get_fq_hostname()
-            return master_addr, get_free_port()
+            if self.use_prefix_store:
+                master_addr = os.environ.get("MASTER_ADDR")
+                master_port = os.environ.get("MASTER_PORT")
+                assert master_addr is not None, (
+                    "DCP needs MASTER_ADDR to use prefix store"
+                )
+                assert master_port is not None, (
+                    "DCP needs MASTER_PORT to use prefix store"
+                )
+                master_port = int(master_port)
+            else:
+                master_addr = os.environ.get("MASTER_ADDR")
+                if master_addr is None:
+                    master_addr = _get_fq_hostname()
+                master_port = get_free_port()
+
+            return master_addr, master_port
 
         self.tcp_store_master_addr, self.tcp_store_master_port = dist_wrapper.broadcast(
             step="get_master_addr_and_port",
@@ -221,10 +237,29 @@ def _checkpointing_subprocess(
         os.environ["WORLD_SIZE"] = str(pg_init_info.world_size)
 
         logger.info(
-            "Initializing dist.ProcessGroup in checkpoint background process"
+            "Initializing dist.ProcessGroup in checkpoint background process on port %s",
+            pg_init_info.tcp_store_master_port,
         )
         # NOTE: GLOO backend is enforced here.
-        dist.init_process_group(backend=dist.Backend.GLOO)
+        if pg_init_info.use_prefix_store:
+            logger.info(
+                "Initializing dist.ProcessGroup in checkpoint background process with prefix store"
+            )
+            store = PrefixStore(
+                "AsyncCheckpointProcess/",
+                TCPStore(
+                    pg_init_info.tcp_store_master_addr,
+                    pg_init_info.tcp_store_master_port,
+                ),
+            )
+            dist.init_process_group(
+                backend=dist.Backend.GLOO,
+                store=store,
+                world_size=pg_init_info.world_size,
+                rank=pg_init_info.global_rank,
+            )
+        else:
+            dist.init_process_group(backend=dist.Backend.GLOO)
         dist.barrier()
 
         logger.info("Checkpoint background process is running...")
@@ -365,7 +400,7 @@ def execute_save(
         global _CHECKPOINT_PROCESS
         pg_init_info: Optional[_ProcessGroupInitInfo] = None
         if _CHECKPOINT_PROCESS is None:
-            # Find a free port on coordinator rank and broadcast
+            # Find a port on coordinator rank and broadcast
             # to all ranks.
             pg_init_info = _ProcessGroupInitInfo(process_group)
 