
Commit d1e62ff

Authored by mauryaavinash95, amaurya, and sfc-gh-truwase
Add DataStates-LLM: Asynchronous Checkpointing Engine Support (#7166)
We are a team at Argonne National Laboratory working on low-overhead asynchronous checkpointing approaches for LLMs and transformers. As part of these efforts, we have developed DataStates-LLM, a library that we would like to contribute to the DeepSpeed community: https://github.com/datastates/datastates-llm

The key idea we leverage is to allow non-blocking tensor copies from the GPU to the host during the forward and backward passes; we block only if these copies have not finished by the start of the update phase. Meanwhile, the tensors are flushed asynchronously from host memory to durable storage (parallel file systems, local SSDs, etc.).

To enable this capability, our initial implementation makes the scheduler aware of checkpointing, calling a ckpt.wait() primitive before starting the update phase. We illustrated this with the pipeline scheduler. We are also considering a scheduler-independent solution that integrates with DeepSpeed/Megatron and provides a hook for the start of the update phase, which we can leverage to run ckpt.wait().

We appreciate your feedback and look forward to a collaboration in this space.

---------

Signed-off-by: amaurya <[email protected]>
Co-authored-by: amaurya <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 64c0052 commit d1e62ff
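To make the overlap pattern described above concrete, here is a minimal, illustrative sketch. All names are hypothetical stand-ins, not the DataStates API: a plain background thread doing CPU copies and `torch.save` substitutes for the non-blocking GPU-to-host copies and asynchronous flush to durable storage that DataStates-LLM actually performs.

```python
import threading
import torch


class AsyncSnapshot:
    """Hypothetical stand-in for an asynchronous checkpointing engine."""

    def __init__(self):
        self._thread = None

    def begin(self, state_dict, path):
        # DataStates-LLM issues non-blocking GPU->host copies, then flushes
        # host buffers to storage asynchronously; a thread stands in for both.
        def _snapshot():
            host_copy = {k: v.detach().cpu().clone() for k, v in state_dict.items()}
            torch.save(host_copy, path)

        self._thread = threading.Thread(target=_snapshot)
        self._thread.start()

    def wait(self):
        # The ckpt.wait() primitive: blocks only if the snapshot taken during
        # forward/backward has not finished by the start of the update phase.
        if self._thread is not None:
            self._thread.join()
            self._thread = None


# Schematic use inside a training step:
#   loss.backward()
#   ckpt.begin(model.state_dict(), f"ckpt-{step}.pt")  # overlaps with compute
#   ckpt.wait()        # block only if copies are still in flight
#   optimizer.step()   # safe: parameters are no longer being read
```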

11 files changed: +183 -5 lines changed

deepspeed/datastates/README.md

Lines changed: 12 additions & 0 deletions
# DataStates-LLM checkpointing engine

This feature is not enabled by default. To enable it, set the following options in ds_config.json and download the [DataStates-LLM checkpointing library](https://github.com/DataStates/datastates-llm/). A detailed tutorial is available [here](../../docs/_tutorials/datastates-async-checkpointing.md).

```
{
    ... other deepspeed config options,
    "datastates_ckpt": {
        "host_cache_size": 16
    }
}
```
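As a hedged illustration (not part of this commit), the config above is consumed through the usual DeepSpeed entry points; `model` here stands for any torch.nn.Module in an otherwise standard training setup:

```python
import deepspeed

# ds_config.json contains the "datastates_ckpt" block shown above.
model_engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                     model_parameters=model.parameters(),
                                                     config="ds_config.json")

# save_checkpoint() now routes through the DataStates engine when the
# datastates package is installed; otherwise DeepSpeed falls back to torch.save.
model_engine.save_checkpoint("./checkpoints")
```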

deepspeed/datastates/__init__.py

Lines changed: 6 additions & 0 deletions
```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team
```

deepspeed/datastates/config.py

Lines changed: 21 additions & 0 deletions
```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team

from deepspeed.runtime.config_utils import DeepSpeedConfigObject
import copy

DATASTATES_CHECKPOINTING = "datastates_ckpt"
DATASTATES_CHECKPOINTING_ENABLED = False


class DeepSpeedDataStatesConfig(DeepSpeedConfigObject):

    def __init__(self, param_dict):
        super(DeepSpeedDataStatesConfig, self).__init__()

        self.enabled = param_dict.get(DATASTATES_CHECKPOINTING, DATASTATES_CHECKPOINTING_ENABLED) is not False
        self.config = copy.deepcopy(param_dict.get(DATASTATES_CHECKPOINTING, None))
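```

A brief, hedged illustration of how this parsing behaves (assumes DeepSpeed is importable; the asserts mirror the `is not False` check above):

```python
from deepspeed.datastates.config import DeepSpeedDataStatesConfig

cfg = DeepSpeedDataStatesConfig({"datastates_ckpt": {"host_cache_size": 16}})
assert cfg.enabled                          # any value other than False enables it
assert cfg.config == {"host_cache_size": 16}

cfg = DeepSpeedDataStatesConfig({})         # key absent -> defaults to False
assert not cfg.enabled and cfg.config is None

cfg = DeepSpeedDataStatesConfig({"datastates_ckpt": False})   # explicit opt-out
assert not cfg.enabled
```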

deepspeed/runtime/checkpoint_engine/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -8,4 +8,5 @@
 from .torch_checkpoint_engine import TorchCheckpointEngine
 from .decoupled_checkpoint_engine import DecoupledCheckpointEngine
 from .checkpoint_engine import CheckpointCommitInfo
+from .datastates_checkpoint_engine import DataStatesCheckpointEngine
 from .utils import create_checkpoint_engine
```

deepspeed/runtime/checkpoint_engine/checkpoint_engine.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -58,3 +58,6 @@ def get_commit_info(self):
 
     def cleanup(self):
         pass
+
+    def preserves_storage_sharing(self):
+        return True
```
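Some context on why this hook exists, shown as a hedged aside (standard PyTorch behavior, assuming a recent PyTorch with untyped storages): `torch.save` serializes a tensor's entire underlying storage, so engines built on it must clone views first via `clone_tensors_for_torch_save`, while an engine that snapshots tensor data itself can skip that copy.

```python
import torch

big = torch.zeros(1_000_000)                    # ~4 MB of float32
view = big[:10]                                 # a view sharing `big`'s storage
print(view.untyped_storage().nbytes())          # 4000000 -- torch.save would write all of it
print(view.clone().untyped_storage().nbytes())  # 40 -- what cloning buys
```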
deepspeed/runtime/checkpoint_engine/datastates_checkpoint_engine.py

Lines changed: 59 additions & 0 deletions

```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team

from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
    CheckpointEngine, CheckpointCommitInfo

ENGINE_NAME = "DataStatesCheckpointEngine"


class DataStatesCheckpointEngine(CheckpointEngine):

    def __init__(self, deepspeed_config, rank):
        super().__init__(deepspeed_config)
        self.commit_info = None
        self.ckpt_engine = None
        try:
            from datastates import CheckpointEngine as DataStatesEngine
            self.ckpt_engine = DataStatesEngine(deepspeed_config, rank)
        except ImportError:
            raise RuntimeError("Please install DataStates from https://github.com/DataStates/datastates-llm.")
        except Exception as e:
            raise RuntimeError(f"An error occurred while initializing DataStates Checkpoint Engine: {e}")

    def __del__(self):
        self.cleanup()

    def create(self, info: CheckpointCommitInfo):
        self.commit_info = info
        return None

    def save(self, state_dict, path: str):
        return self.ckpt_engine.save(state_dict, path)

    def load(self, path: str, map_location=None):
        return self.ckpt_engine.load(path, map_location)

    def commit(self, info: CheckpointCommitInfo):
        if info is None:
            return
        assert info == self.commit_info
        self.ckpt_engine.wait(persist=True)
        self.commit_info = None
        return True

    def cleanup(self):
        self.commit(self.commit_info)
        if self.ckpt_engine:
            self.ckpt_engine.wait(persist=True)
        del self.ckpt_engine

    def is_decoupled(self):
        return True

    def preserves_storage_sharing(self):
        return False
```
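A hedged sketch of the call sequence this engine expects (assumes the `datastates` package is installed; `config`, `info`, and `state` are placeholders supplied by DeepSpeed rather than constructed here):

```python
engine = DataStatesCheckpointEngine(deepspeed_config=config, rank=0)

engine.create(info)            # record the pending commit for this checkpoint
engine.save(state, "ckpt.pt")  # returns quickly; copy/flush proceeds asynchronously
engine.commit(info)            # blocks in ckpt_engine.wait(persist=True) until the
                               # checkpoint is durable, then clears the pending info
```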

deepspeed/runtime/checkpoint_engine/utils.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -6,7 +6,7 @@
 from deepspeed.runtime.model_checkpointing.constants import *
 from deepspeed.runtime.model_checkpointing.utils import create_data_parallel_writer_config
 from deepspeed.utils import logger
-
+from deepspeed import comm as dist
 from .decoupled_checkpoint_engine import DecoupledCheckpointEngine
 from .fast_checkpoint_engine import FastCheckpointEngine
 from .torch_checkpoint_engine import TorchCheckpointEngine
@@ -35,4 +35,14 @@ def create_checkpoint_engine(config_params, groups, zero_stage, has_moe_layers,
     else:
         return NebulaCheckpointEngine(config_params=config_params.nebula_config)
 
+    if config_params.datastates_config.enabled:
+        try:
+            from .datastates_checkpoint_engine import DataStatesCheckpointEngine
+            return DataStatesCheckpointEngine(deepspeed_config=config_params, rank=dist.get_rank())
+        except ImportError as err:
+            logger.error(
+                f"No datastates engine found! Install from https://github.com/DataStates/datastates-llm. Will fall back to torch.save. Details: {err}"
+            )
+            return TorchCheckpointEngine(config_params)
+
     return TorchCheckpointEngine(config_params)
```

deepspeed/runtime/config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -52,6 +52,7 @@
 from ..profiling.config import DeepSpeedFlopsProfilerConfig
 from ..autotuning.config import DeepSpeedAutotuningConfig
 from ..nebula.config import DeepSpeedNebulaConfig
+from ..datastates.config import DeepSpeedDataStatesConfig
 
 from ..compression.config import get_compression_config, get_quantize_enabled
 from ..compression.constants import *
@@ -859,6 +860,7 @@ def _initialize_params(self, param_dict):
         self.dataloader_drop_last = get_dataloader_drop_last(param_dict)
 
         self.nebula_config = DeepSpeedNebulaConfig(param_dict)
+        self.datastates_config = DeepSpeedDataStatesConfig(param_dict)
         self.checkpoint_config = get_checkpoint_config(param_dict)
 
         self.weight_quantization_config = WeightQuantConfig(
```

deepspeed/runtime/engine.py

Lines changed: 9 additions & 3 deletions
```diff
@@ -3612,7 +3612,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
             moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu)
             if self.random_ltd_enabled():
                 expert_state_dict = remove_random_ltd_state_dict(expert_state_dict)
-            saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict)
+            saveable_state_dict = expert_state_dict
+            if self.checkpoint_engine.preserves_storage_sharing():
+                saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict)
             self.checkpoint_engine.save(saveable_state_dict, moe_save_path)
             moe_layer_id += 1
 
@@ -3634,7 +3636,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
         }
         # TODO: why use BufferedWriter not the path
         file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
-        saveable_state_dict = clone_tensors_for_torch_save(optimizer_state)
+        saveable_state_dict = optimizer_state
+        if self.checkpoint_engine.preserves_storage_sharing():
+            saveable_state_dict = clone_tensors_for_torch_save(optimizer_state)
         self.checkpoint_engine.save(saveable_state_dict, file_path)
 
         # Load flow uses below saved file for model parameters, RNG and more
@@ -3674,7 +3678,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
         }
         state.update(client_state)
         logger.info(f'Saving model checkpoint: {save_path}')
-        saveable_state_dict = clone_tensors_for_torch_save(state)
+        saveable_state_dict = state
+        if self.checkpoint_engine.preserves_storage_sharing():
+            saveable_state_dict = clone_tensors_for_torch_save(state)
         self.checkpoint_engine.save(saveable_state_dict, save_path)
 
     def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
```

deepspeed/runtime/pipe/module.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -621,6 +621,7 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
         layer_list = self.forward_funcs[start:end]
 
         checkpoint_engine.makedirs(save_dir, exist_ok=True)
+        should_clone = checkpoint_engine.preserves_storage_sharing()
         for idx, layer in enumerate(layer_list):
             model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
             if not hasattr(layer, 'state_dict'):
@@ -630,7 +631,9 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
             if exclude_frozen_params:
                 for n in self._get_frozen_parameter_names(layer):
                     del orig_state_dict[n]
-            final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
+            final_state_dict = orig_state_dict
+            if should_clone:
+                final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
             checkpoint_engine.save(state_dict=final_state_dict, path=model_ckpt_path)
 
     def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
```
