
Commit 79343ef

Authored by meichangsu1, 01191421 and BalaBalaYi
support deepspeed elastic (#1672)
* support elastic speed
* support elastic speed
* support elastic speed
* add doc for deepspeed elastic
* design doc
* feat: add support for additional arguments in DeepSpeed checkpoint save method
  Add a **kwargs parameter to the save method in AsyncCheckpointAgent so that additional arguments, such as storage options, can be passed when saving checkpoints. This adds flexibility for future extensions without breaking existing functionality.
* feat: remove model state skip logic in checkpoint saver
  The condition that skipped saving model states when `ckpt_config.write_model` is false has been removed. All states specified in `ckpt_config.paths` are now saved regardless of the `write_model` flag, aligning the behavior with the configured paths and preventing unintended omissions during checkpoint operations.
* Same `write_model` removal as above. Additionally, code formatting and style improvements were applied across multiple files, including:
  - Added missing newlines between class definitions in `comm.py`
  - Standardized spacing around operators and commas in `master_client.py` and `ckpt_saver.py`
  - Updated Chinese comments to English in `ckpt_saver.py`
  - Enhanced test coverage for graceful worker exit scenarios in `torch_agent_test.py`
* feat: improve universal checkpoint logic and code formatting
* feat: reduce checkpoint waiting timeout from 300 to 60 seconds
* feat: fix shared memory handling and add rank parameter to checkpoint saver
  - Set shared_memory to None after closing to prevent reuse of closed memory
  - Add a check for shared_memory.buf in SharedMemoryHandler.get() to avoid errors
  - Add a rank parameter to TempDirCheckpointSaver.__init__ for proper initialization
  - Fix test formatting and remove an unused variable in the checkpoint saver tests
* feat: fix lint and ci
* feat: fix lint and ci
* feat: fix lint and ci
* feat(ckpt_saver): skip persistence when checkpoint config is missing
  Add a guard clause in `persist_to_storage` that skips the persistence operation if the checkpoint config is `None` or has no paths. This prevents potential errors when the checkpoint configuration is incomplete or unavailable and lets the saver handle missing configurations gracefully.
* refactor: rename UCP-related classes and methods for clarity
  - Rename `UCPReady` to `PreviousRoundCompleted` to better reflect its purpose of indicating that the previous rendezvous round has completed
  - Rename `UCPReadyRequest` to `PreviousRoundCompletedRequest` for consistency
  - Update the related method names (`get_ucp_ready`, `set_ucp_ready`) to `get_previous_round_completed` and `set_previous_round_completed`
  - Rename the instance variable `ucp_ready` to `previous_round_completed` in `RendezvousManager`
  - Improve documentation strings to clarify the purpose of tracking previous-round completion status
* feat: replace DLROVER_UCP_RESTART with DLROVER_TRAINING_ELASTIC_MODE
  - Update the environment variable from DLROVER_UCP_RESTART to DLROVER_TRAINING_ELASTIC_MODE for better clarity and extensibility
  - Change condition checks from `enable_ucp == "true"` to `elastic_mode == "ucp"` to support multiple elastic training modes
  - Remove the previous_round_completed logic from the base RendezvousManager to simplify state management
  - Introduce a create_training_rdzv_manager factory function for flexible rendezvous manager creation
  - Centralize elastic mode configuration through the environment variable for consistent behavior across components
* refactor: replace previous round completion with rendezvous blocking
  - Rename the `PreviousRoundCompleted` message to `RdzvBlocked` with a `blocked` boolean and a `reason` string field
  - Remove the `PreviousRoundCompletedRequest` message as it is no longer used
  - Update the `MasterClient` methods to use `set_rdzv_blocked` instead of `set_previous_round_completed` and remove `get_previous_round_completed`
  - Modify `ElasticTrainingAgent` to call `set_rdzv_blocked` when the UCP elastic mode is active
  - Add `_rdzv_blocked` and `_rdzv_block_reason` state to `RendezvousManager` with corresponding setter and getter methods
  - Update `_pre_rdzv_check_hook` to return the rendezvous blocked state and reason
  This change simplifies rendezvous state tracking by consolidating completion status into a blocking mechanism with an optional reason, improving clarity and flexibility for UCP elastic training scenarios.
* refactor: remove unnecessary blank lines and unused imports
  - Remove extra blank lines in `comm.py`, `rdzv_manager.py`, and test files to improve code readability
  - Remove the unused import of `ElasticTrainingRendezvousManager` in `test_servicer.py` to clean up dependencies
* lint fix
* refactor seen_new_saving -> need_new_saving
* feat: add training elastic mode configuration for rendezvous manager
  - Add the DLROVER_TRAINING_ELASTIC_MODE environment variable constant
  - Introduce a training_elastic_mode default value and context attribute
  - Add a --training_elastic_mode argument to the master CLI with default "base"
  - Update the rendezvous manager factory to use the context instead of the environment variable
  - Pass training_elastic_mode from the CLI args to the job context in master initialization
  This centralizes configuration of the training elastic mode (base/ucp) through the master's command-line interface and global context, replacing the previous environment-variable approach for better consistency and configurability.
* fix ci & unit test
* feat: add training elastic mode argument to job master
  Add a new constant `trainingElasticModeArg` to the master controller and include it in the list of master arguments. This allows the job master to accept a `--training_elastic_mode` flag, enabling support for different training elasticity modes such as 'ucp'.
  The test has been updated to verify that the new argument is correctly passed to the master pod command.

---------

Co-authored-by: 01191421 <lijialin1014@cmbchina.com>
Co-authored-by: Tianyi Chen <chentianyi.cty@antfin.com>
1 parent 8abc858 commit 79343ef

32 files changed · +1679 −56 lines changed

dlrover/python/common/comm.py

Lines changed: 8 additions & 0 deletions
@@ -223,6 +223,14 @@ class GlobalStep(Message):
     elapsed_time_per_step: float = 0.0


+@dataclass
+class RdzvBlocked(Message):
+    """Indicate whether rendezvous completion is blocked."""
+
+    blocked: bool = False
+    reason: str = ""
+
+
 @dataclass
 class HeartBeat(Message):
     timestamp: int = 0
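
Note: the sketch below shows how the new RdzvBlocked message is populated. It is a self-contained reproduction for illustration only; the trivial Message stand-in replaces the real base class in dlrover.python.common.comm, which is not part of this diff.

from dataclasses import dataclass


@dataclass
class Message:
    # Stand-in for dlrover.python.common.comm.Message (not shown in this diff).
    pass


@dataclass
class RdzvBlocked(Message):
    """Indicate whether rendezvous completion is blocked."""

    blocked: bool = False
    reason: str = ""


# Default: the rendezvous is not blocked.
assert RdzvBlocked() == RdzvBlocked(blocked=False, reason="")

# A node that is still converting a universal checkpoint can explain why the
# next rendezvous round should wait.
print(RdzvBlocked(blocked=True, reason="universal checkpoint conversion in progress"))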

dlrover/python/common/constants.py

Lines changed: 1 addition & 0 deletions
@@ -315,6 +315,7 @@ class NodeEnv(object):
     RELAUNCHED_POD = "RELAUNCHED_POD"
     DLROVER_MASTER_ADDR = "DLROVER_MASTER_ADDR"
     DLROVER_MASTER_SERVICE_TYPE = "DLROVER_MASTER_SERVICE_TYPE"
+    DLROVER_TRAINING_ELASTIC_MODE = "DLROVER_TRAINING_ELASTIC_MODE"
     GRPC_ENABLE_FORK = "GRPC_ENABLE_FORK_SUPPORT"
     GRPC_POLL_STRATEGY = "GRPC_POLL_STRATEGY"
     POD_NAME = "POD_NAME"

dlrover/python/common/global_context.py

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,7 @@ class DefaultValues(object):
     FIRST_GROUP_IDX = 1000  # group idx initial value for group relaunch
     MAX_RELAUNCH_COUNT = 3  # maximum node relaunch count
     MAX_GROUP_RELAUNCH_COUNT = 3  # maximum node group relaunch count
+    TRAINING_ELASTIC_MODE = "base"


 class Context(Singleton):
@@ -146,6 +147,7 @@ def __init__(self):
         self.pre_check_operators = DefaultValues.PRE_CHECK_OPS
         self.max_relaunch_count = DefaultValues.MAX_RELAUNCH_COUNT
         self.max_group_relaunch_count = DefaultValues.MAX_GROUP_RELAUNCH_COUNT
+        self.training_elastic_mode = DefaultValues.TRAINING_ELASTIC_MODE

     def set_params_from_brain(self):
         self.train_speed_record_num = self.get_param_value_from_brain(
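
Per the commit message, this default flows from a new --training_elastic_mode master CLI flag into the job context. A minimal argparse sketch of that intended flow, assuming only the flag name and the "base"/"ucp" values described above; the context class here is an illustrative stand-in, not the real dlrover Context.

import argparse


class JobContextStub:
    """Illustrative stand-in for the global job context."""

    def __init__(self):
        self.training_elastic_mode = "base"  # mirrors DefaultValues.TRAINING_ELASTIC_MODE


def parse_master_args(argv):
    parser = argparse.ArgumentParser("dlrover-master")
    # "base" keeps the existing behaviour; "ucp" turns on universal-checkpoint
    # elastic training, per the commit message.
    parser.add_argument("--training_elastic_mode", type=str, default="base")
    return parser.parse_args(argv)


args = parse_master_args(["--training_elastic_mode", "ucp"])
ctx = JobContextStub()
ctx.training_elastic_mode = args.training_elastic_mode
assert ctx.training_elastic_mode == "ucp"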

dlrover/python/elastic_agent/master_client.py

Lines changed: 4 additions & 0 deletions
@@ -515,6 +515,10 @@ def report_event(
         )
         self._report(message)

+    def set_rdzv_blocked(self, blocked, reason=""):
+        message = comm.RdzvBlocked(blocked=blocked, reason=reason)
+        self._report(message)
+
     @classmethod
     def singleton_instance(cls, *args, **kwargs):
         if not cls._instance:
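
How an agent is expected to use this: per the commit message, ElasticTrainingAgent calls set_rdzv_blocked when the "ucp" elastic mode is active so the master can hold the next rendezvous round. A rough sketch of the call shape, with the RPC transport replaced by a print so it runs standalone; the class below is a stand-in, not the real MasterClient.

from dataclasses import dataclass


@dataclass
class RdzvBlocked:
    blocked: bool = False
    reason: str = ""


class MasterClientStub:
    """Stand-in for MasterClient; the real client ships the message over RPC."""

    def _report(self, message):
        print(f"report -> {message}")

    def set_rdzv_blocked(self, blocked, reason=""):
        message = RdzvBlocked(blocked=blocked, reason=reason)
        self._report(message)


client = MasterClientStub()
# Block the next round while the previous round's checkpoint work is still running...
client.set_rdzv_blocked(True, reason="universal checkpoint conversion in progress")
# ...and release it once the conversion is done.
client.set_rdzv_blocked(False)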

dlrover/python/elastic_agent/torch/ckpt_saver.py

Lines changed: 147 additions & 10 deletions
@@ -423,19 +423,22 @@ def __init__(
         local_shard_num=1,
         global_shard_num=1,
         save_timeout=CheckpointConstant.SAVE_TIMEOUT,
+        rank=0,
     ) -> None:
         logger.info(
             "Initializing the AsyncSaver with arguments: "
             f"checkpoint_dir={checkpoint_dir}, "
             f"local_shard_num={local_shard_num}, "
             f"global_shard_num={global_shard_num}, "
             f"save_timeout={save_timeout}"
+            f"rank={rank}"
         )
         self.checkpoint_dir = checkpoint_dir
         self.local_shard_num = local_shard_num
         self.global_shard_num = global_shard_num
-        self._node_rank = env_utils.get_node_rank()
-        self._is_agent_rank_0 = self._node_rank == 0
+        self._node_rank = env_utils.get_rank()
+        self._rank = rank
+        self._is_agent_rank_0 = self._rank == 0
         self._shm_handlers: List[SharedMemoryHandler] = []
         self._shm_locks: List[SharedLock] = []
         self._stop_commit = False
@@ -508,8 +511,15 @@ def _factory():
                 and saver_thread
                 and saver_thread.is_alive()
             ):
-                logger.info(
-                    "The saver is already created, skip creating the saver."
+                # bugfix: the checkpoint dir changes with every training run
+                cls._saver_instance.checkpoint_dir = (
+                    class_meta.kwargs.get("checkpoint_dir")
+                )
+                cls._saver_instance._rank = class_meta.kwargs.get(
+                    "rank"
+                )
+                cls._saver_instance._is_agent_rank_0 = (
+                    cls._saver_instance._rank == 0
                 )
                 continue

@@ -529,6 +539,33 @@ def _factory():
     def get_ckpt_saver(cls):
         return cls._saver_instance

+    def ucp(self, input_dir: str, output_dir: str, ucp_device_type: str):
+        """universal checkpoint"""
+        pass
+
+    def get_latest_start_saving_step(self):
+        step = self._latest_step
+        steps = []
+        for shm_handler in self._shm_handlers:
+            default_config = CheckpointConfig()
+            config = shm_handler.get_checkpoint_config(default_config)
+            steps.append(config.step)
+        if len(steps) > 0:
+            step = steps[0]
+        return step
+
+    def get_latest_success_save_dir(self):
+        try:
+            tracker_file = os.path.join(
+                self.checkpoint_dir, CheckpointConstant.TRACER_FILE_NAME
+            )
+            with open(tracker_file, "r") as f:
+                step = int(f.read())
+        except FileNotFoundError:
+            return None, None
+
+        return self.checkpoint_dir, step
+
     @classmethod
     def register_signal_handler(cls):
         sigint_handler = signal.getsignal(signal.SIGINT)
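
The resume path added in the hunk above leans on the saver's tracker-file convention: update_tracker_file writes the last committed step into a tracer file under the checkpoint directory, and get_latest_success_save_dir reads it back. A small standalone sketch of that round trip; the file name used here is illustrative, the real one comes from CheckpointConstant.TRACER_FILE_NAME.

import os
import tempfile

TRACER_FILE_NAME = "dlrover_latest.txt"  # illustrative; real value: CheckpointConstant.TRACER_FILE_NAME


def get_latest_success_save_dir(checkpoint_dir):
    """Mirror of the method above: return (checkpoint_dir, step) or (None, None)."""
    try:
        tracker_file = os.path.join(checkpoint_dir, TRACER_FILE_NAME)
        with open(tracker_file, "r") as f:
            step = int(f.read())
    except FileNotFoundError:
        return None, None
    return checkpoint_dir, step


ckpt_dir = tempfile.mkdtemp()
print(get_latest_success_save_dir(ckpt_dir))  # (None, None): nothing committed yet
with open(os.path.join(ckpt_dir, TRACER_FILE_NAME), "w") as f:
    f.write("1200")  # pretend step 1200 was fully persisted
print(get_latest_success_save_dir(ckpt_dir))  # (ckpt_dir, 1200)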
@@ -709,7 +746,7 @@ def _save_shard(
             shm_lock.release()

     def _dist_make_dir(self, path, timeout=30):
-        if self._node_rank == 0:
+        if self._rank == 0:
             logger.info(f"Create path by rank0 worker: {path}.")
             self.storage.safe_rmtree(path)
             self.storage.safe_makedirs(path)
@@ -719,7 +756,7 @@ def _dist_make_dir(self, path, timeout=30):
                     return
                 time.sleep(1)
             logger.warning(
-                f"Worker {self._node_rank} can't find path {path} "
+                f"Worker {self._rank} can't find path {path} "
                 f"with timeout {timeout}."
             )

@@ -1042,11 +1079,47 @@ def commit_checkpoint(self, step: int, step_done_dir: str, timeout=600):
     def persist_to_storage(
         self, local_shard_id: int, ckpt_config: CheckpointConfig
     ):
+        if ckpt_config is None or not ckpt_config.paths:
+            logger.info(
+                "Skip persisting checkpoint because checkpoint config is missing."
+            )
+            return
         state_dict = self._shm_handlers[local_shard_id].load_state_dict()
+        safe_serialization = None
+        if "safe_serialization" in state_dict:
+            safe_serialization = state_dict.pop("safe_serialization")
         for state_name, sd in state_dict.items():
             if sd and state_name in ckpt_config.paths:
+                from transformers.utils import (
+                    ADAPTER_SAFE_WEIGHTS_NAME,
+                    ADAPTER_WEIGHTS_NAME,
+                    SAFE_WEIGHTS_NAME,
+                    WEIGHTS_NAME,
+                )
+                from safetensors.torch import save_file as safe_save_file
+                import re
+
                 path = ckpt_config.paths[state_name]
-                self.storage.write_state_dict(sd, path, torch.save)
+                reg = re.compile(r"(.*?)-\d{5}-of-\d{5}.bin")
+                match = reg.fullmatch(state_name)
+                if safe_serialization:
+                    if state_name.endswith(ADAPTER_WEIGHTS_NAME):
+                        path = path.replace(
+                            ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME
+                        )
+                        self.storage.write_state_dict(sd, path, safe_save_file)
+                    elif state_name.endswith(WEIGHTS_NAME):
+                        path = path.replace(WEIGHTS_NAME, SAFE_WEIGHTS_NAME)
+                        self.storage.write_state_dict(sd, path, safe_save_file)
+                    elif match:
+                        path = path.replace("pytorch_model", "model").replace(
+                            ".bin", ".safetensors"
+                        )
+                        self.storage.write_state_dict(sd, path, safe_save_file)
+                    else:
+                        self.storage.write_state_dict(sd, path, torch.save)
+                else:
+                    self.storage.write_state_dict(sd, path, torch.save)


 class TempDirCheckpointSaver(AsyncCheckpointSaver):
10631136
local_shard_num=1,
10641137
global_shard_num=1,
10651138
save_timeout=CheckpointConstant.SAVE_TIMEOUT,
1139+
rank=0,
10661140
) -> None:
10671141
super().__init__(
10681142
checkpoint_dir,
10691143
storage_meta,
10701144
local_shard_num,
10711145
global_shard_num,
10721146
save_timeout,
1147+
rank=rank,
10731148
)
10741149

10751150
if self._node_rank == 0:
@@ -1267,9 +1342,6 @@ class DdpCheckpointSaver(CommonDirCheckpointSaver):
     """Persist the checkpoint from CPU memory buffer into the storage."""

     def persist_to_storage(self, local_shard_id: int, ckpt_config):
-        if self._node_rank != 0:
-            logger.info("Skip and only rank 0 saves checkpoint in a DDP job.")
-            return
         super().persist_to_storage(local_shard_id, ckpt_config)

@@ -1310,6 +1382,70 @@ def update_tracker_file(self, step):
         )
         self.storage.write(str(step), ds_tracker_filename)

+    def get_deepspeed_install_dir(self):
+        spec = importlib.util.find_spec("deepspeed")
+        deepspeed_dir = ""
+        if spec and spec.origin:
+            # spec.origin points to the package's __init__.py file
+            module_path = spec.origin
+            deepspeed_dir = os.path.dirname(module_path)
+        return deepspeed_dir
+
+    def ucp(self, input_dir: str, output_dir: str, ucp_device_type: str):
+        import torch
+        from packaging import version
+        from torch.distributed.elastic.multiprocessing.api import (
+            SubprocessHandler,
+        )
+
+        def version_less_than_230():
+            current_version = version.parse(torch.__version__).base_version
+            return version.parse(current_version) <= version.parse("2.2.2")
+
+        def version_less_than_240():
+            current_version = version.parse(torch.__version__).base_version
+            return version.parse(current_version) <= version.parse("2.3.1")
+
+        import sys
+        import os
+
+        cmd = os.getenv("PYTHON_EXEC", sys.executable)
+        deepspeed_dir = self.get_deepspeed_install_dir()
+        args_list = [
+            deepspeed_dir + "/checkpoint/ds_to_universal.py",
+            "--input_folder",
+            f"{input_dir}",
+            "--output_folder",
+            f"{output_dir}",
+            "--inject_missing_state",
+        ]
+        if ucp_device_type != "cpu":
+            args_list.extend(["--device", ucp_device_type])
+        args = tuple(args_list)
+        if version_less_than_230():
+            handler = SubprocessHandler(cmd, args, {}, "", "")
+        else:
+            handler = SubprocessHandler(cmd, args, {}, "", "", 0)
+        ret = handler.proc.wait()
+        if ret != 0:
+            print(f"subprocess returned non-zero exit code {ret}")
+            return False
+        else:
+            return True
+
+    def get_latest_success_save_dir(self):
+        try:
+            tracker_file = os.path.join(
+                self.checkpoint_dir, CheckpointConstant.TRACER_FILE_NAME
+            )
+
+            with open(tracker_file, "r") as f:
+                step = int(f.read())
+        except FileNotFoundError:
+            return None, None
+
+        return self.checkpoint_dir, step
+

 class FsdpDcpSaver(CommonDirCheckpointSaver):
     """The saver saves the distributed checkpoint of FSDP into the storage."""
@@ -1355,3 +1491,4 @@ def persist_to_storage(
             self.checkpoint_dir, CheckpointConstant.TRACER_FILE_NAME
         )
         self.storage.write(str(ckpt_config.step), tracer_file)
+        dcp_metadata = meta_dict.get("dcp_metadata", {})
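
Taken together, the DeepSpeed saver's ucp() method above shells out to the ds_to_universal.py script shipped inside the installed deepspeed package to turn a ZeRO checkpoint into a universal checkpoint that a resized job can resume from. Below is a rough standalone equivalent using subprocess.run instead of torch's SubprocessHandler; the flags mirror the ones assembled in the diff, and the example paths are hypothetical.

import importlib.util
import os
import subprocess
import sys


def deepspeed_install_dir():
    """Locate the installed deepspeed package, like get_deepspeed_install_dir above."""
    spec = importlib.util.find_spec("deepspeed")
    return os.path.dirname(spec.origin) if spec and spec.origin else ""


def convert_to_universal(input_dir, output_dir, device_type="cpu"):
    """Run DeepSpeed's ds_to_universal.py; return True on success."""
    cmd = [
        os.getenv("PYTHON_EXEC", sys.executable),
        os.path.join(deepspeed_install_dir(), "checkpoint", "ds_to_universal.py"),
        "--input_folder", input_dir,
        "--output_folder", output_dir,
        "--inject_missing_state",
    ]
    if device_type != "cpu":
        cmd.extend(["--device", device_type])
    ret = subprocess.run(cmd).returncode
    if ret != 0:
        print(f"subprocess returned non-zero exit code {ret}")
        return False
    return True


# Hypothetical usage: convert the latest ZeRO checkpoint before the next
# rendezvous so workers with a different world size can load it.
# convert_to_universal("/ckpt/global_step1200", "/ckpt/global_step1200_universal")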
