Skip to content

Commit ee81366

Browse files
[checkpointio] support load-pin overlap (#6177)
* [checkpointio] support load-pin overlap
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [test] add conftest
--------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 479067e commit ee81366

File tree

6 files changed

+56
-32
lines changed

6 files changed

+56
-32
lines changed

colossalai/booster/plugin/gemini_plugin.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
create_pinned_state_dict,
2121
get_model_base_filenames,
2222
get_optimizer_base_filenames,
23-
load_shard_state_dict,
23+
load_state_dict_shards,
2424
save_config_file,
2525
save_state_dict,
2626
save_state_dict_shards,
@@ -29,7 +29,6 @@
2929
from colossalai.interface import ModelWrapper, OptimizerWrapper
3030
from colossalai.logging import get_dist_logger
3131
from colossalai.shardformer import ShardConfig, ShardFormer
32-
from colossalai.utils.safetensors import load_flat
3332
from colossalai.zero import GeminiDDP, GeminiOptimizer
3433
from colossalai.zero.gemini.memory_tracer import MemStats
3534

@@ -350,11 +349,9 @@ def load_sharded_optimizer(
350349

351350
# Load optimizer states from shard files under checkpoint path.
352351
# For each file, only load the states managed by current process.
353-
for shard_file in checkpoint_files:
354-
if shard_file.endswith(".safetensors"):
355-
state_dict_shard = load_flat(shard_file)
356-
else:
357-
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
352+
for state_dict_shard in load_state_dict_shards(
353+
checkpoint_files, True, False, low_cpu_mem_mode=low_cpu_mem_mode
354+
):
358355
if not low_cpu_mem_mode:
359356
state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
360357
optimizer.load_param_states(state_dict_shard)

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
get_optimizer_base_filenames,
2525
get_shard_filename,
2626
load_param_groups_into_optimizer,
27-
load_shard_state_dict,
2827
load_state_dict,
28+
load_state_dict_shards,
2929
load_states_into_optimizer,
3030
save_param_groups,
3131
save_state_dict,
@@ -276,13 +276,7 @@ def load_sharded_optimizer(
276276

277277
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
278278

279-
for shard_file in checkpoint_files:
280-
if shard_file.endswith(".safetensors"):
281-
from colossalai.utils.safetensors import load_flat
282-
283-
state_dict = load_flat(shard_file)
284-
else:
285-
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
279+
for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
286280
# shard state dict
287281
for param_idx, state in state_dict.items():
288282
for k, v in state.items():

colossalai/booster/plugin/torch_fsdp_plugin.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,8 @@ def load_sharded_model(
255255
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
256256

257257
fsdp_state_dict = {}
258-
for shard_file in checkpoint_files:
259-
fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))
258+
for state_dict in utils.load_state_dict_shards(checkpoint_files, False, use_safetensors):
259+
fsdp_state_dict.update(state_dict)
260260

261261
with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
262262
model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
@@ -388,11 +388,7 @@ def load_sharded_optimizer(
388388
# Load param
389389
fsdp_optim_state = {}
390390
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
391-
for shard_file in checkpoint_files:
392-
if shard_file.endswith(".safetensors"):
393-
state_dict_shard = load_flat(shard_file, seperator=".")
394-
else:
395-
state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
391+
for state_dict_shard in utils.load_state_dict_shards(checkpoint_files, True, False):
396392
fsdp_optim_state.update(state_dict_shard)
397393

398394
fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)

colossalai/checkpoint_io/general_checkpoint_io.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
get_optimizer_base_filenames,
1919
is_safetensors_available,
2020
load_param_groups_into_optimizer,
21-
load_shard_state_dict,
2221
load_state_dict,
2322
load_state_dict_into_model,
23+
load_state_dict_shards,
2424
load_states_into_optimizer,
2525
save_config_file,
2626
save_param_groups,
@@ -94,11 +94,7 @@ def load_sharded_optimizer(
9494

9595
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
9696

97-
for shard_file in checkpoint_files:
98-
if shard_file.endswith(".safetensors"):
99-
state_dict = load_flat(shard_file)
100-
else:
101-
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
97+
for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
10298
if not low_cpu_mem_mode:
10399
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
104100
load_states_into_optimizer(optimizer, state_dict, id_map)
@@ -295,8 +291,7 @@ def load_sharded_model(
295291
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
296292
missing_keys = []
297293

298-
for shard_file in checkpoint_files:
299-
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
294+
for state_dict in load_state_dict_shards(checkpoint_files, False, use_safetensors, low_cpu_mem_mode):
300295
if not low_cpu_mem_mode:
301296
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
302297
load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)

colossalai/checkpoint_io/utils.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections import defaultdict
77
from itertools import chain
88
from pathlib import Path
9-
from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
9+
from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
1010

1111
import torch
1212
import torch.nn as nn
@@ -21,7 +21,7 @@
2121
to_global,
2222
to_global_for_customized_distributed_tensor,
2323
)
24-
from colossalai.utils.safetensors import _flatten_optim_state_dict
24+
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
2525

2626
SAFE_WEIGHTS_NAME = "model.safetensors"
2727
WEIGHTS_NAME = "pytorch_model.bin"
@@ -972,3 +972,35 @@ def create_pinned_state_dict(
972972
idx = future_to_idx[future]
973973
elems[idx] = future.result()
974974
return tree_unflatten(elems, spec)
975+
976+
977+
def load_optim_or_model_shard(path: str, is_optim: bool, use_safetensors: bool) -> dict:
    """Load a single checkpoint shard file into a state dict.

    Args:
        path (str): Path to the shard file on disk.
        is_optim (bool): Whether the shard holds optimizer states. Optimizer
            shards stored as safetensors are flattened and must be read with
            ``load_flat``.
        use_safetensors (bool): For model shards, whether the file is in
            safetensors format.

    Returns:
        dict: The state dict loaded from the shard.
    """
    if not is_optim:
        # Model shard: the caller-provided flag decides the file format.
        return load_shard_state_dict(Path(path), use_safetensors)
    # Optimizer shard: pick the reader from the file extension.
    if path.endswith(".safetensors"):
        return load_flat(path)
    return load_shard_state_dict(Path(path), use_safetensors=False)
986+
987+
988+
def load_state_dict_shards(
    checkpoint_files: List[str],
    is_optim: bool,
    use_safetensors: bool,
    low_cpu_mem_mode: bool = True,
    prefetch: int = 3,
) -> Generator[dict, None, None]:
    """Yield state dicts loaded from a list of checkpoint shard files.

    Args:
        checkpoint_files (List[str]): Paths of the shard files to load.
        is_optim (bool): Whether the shards hold optimizer states.
        use_safetensors (bool): Whether model shards use safetensors format.
        low_cpu_mem_mode (bool): When True, load shards one at a time
            sequentially. When False, load them with a thread pool so disk
            reads overlap with downstream work (e.g. pinning); shards are then
            yielded in completion order, not input order.
        prefetch (int): Number of loader threads when overlapping.

    Yields:
        dict: One loaded shard state dict at a time.
    """
    if low_cpu_mem_mode:
        # Sequential path: at most one shard is resident at a time.
        for path in checkpoint_files:
            yield load_optim_or_model_shard(path, is_optim, use_safetensors)
        return
    # Overlapped path: prefetch shards with a small thread pool.
    # NOTE(review): every file is submitted up front, so if the consumer is
    # slower than the loaders, completed shards can accumulate in memory
    # ahead of consumption — confirm this is acceptable for large checkpoints.
    with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as pool:
        pending = [
            pool.submit(load_optim_or_model_shard, path, is_optim, use_safetensors)
            for path in checkpoint_files
        ]
        for done in concurrent.futures.as_completed(pending):
            yield done.result()

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import gc
2+
3+
from colossalai.accelerator import get_accelerator
4+
5+
6+
def pytest_runtest_setup(item):
    """Pytest hook run before each test item.

    Empties the accelerator's cached device memory and forces a garbage
    collection so memory pressure from one test does not leak into the next.
    """
    get_accelerator().empty_cache()
    gc.collect()

0 commit comments

Comments
 (0)