@@ -6,7 +6,7 @@
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
+from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -21,7 +21,7 @@
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.utils.safetensors import _flatten_optim_state_dict
+from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -972,3 +972,35 @@ def create_pinned_state_dict(
             idx = future_to_idx[future]
             elems[idx] = future.result()
     return tree_unflatten(elems, spec)
+
+
+def load_optim_or_model_shard(path: str, is_optim: bool, use_safetensors: bool) -> dict:
+    if is_optim:
+        if path.endswith(".safetensors"):
+            state_dict = load_flat(path)
+        else:
+            state_dict = load_shard_state_dict(Path(path), use_safetensors=False)
+    else:
+        state_dict = load_shard_state_dict(Path(path), use_safetensors)
+    return state_dict
+
+
+def load_state_dict_shards(
+    checkpoint_files: List[str],
+    is_optim: bool,
+    use_safetensors: bool,
+    low_cpu_mem_mode: bool = True,
+    prefetch: int = 3,
+) -> Generator[dict, None, None]:
+    if low_cpu_mem_mode:
+        for shard_file in checkpoint_files:
+            state_dict = load_optim_or_model_shard(shard_file, is_optim, use_safetensors)
+            yield state_dict
+    else:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as executor:
+            futures = []
+            for shard_file in checkpoint_files:
+                future = executor.submit(load_optim_or_model_shard, shard_file, is_optim, use_safetensors)
+                futures.append(future)
+            for future in concurrent.futures.as_completed(futures):
+                yield future.result()
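A minimal usage sketch of the new generator, not part of the diff: it assumes these helpers live in colossalai/checkpoint_io/utils.py alongside load_shard_state_dict, and the shard file paths are hypothetical (in practice they come from the sharded checkpoint's index file).

```python
# Minimal sketch, assuming load_state_dict_shards is exported from
# colossalai.checkpoint_io.utils; the shard paths below are hypothetical.
from colossalai.checkpoint_io.utils import load_state_dict_shards

checkpoint_files = [
    "ckpt/model-00001-of-00002.safetensors",
    "ckpt/model-00002-of-00002.safetensors",
]

# low_cpu_mem_mode=False takes the ThreadPoolExecutor path, prefetching up to
# `prefetch` shards concurrently at the cost of holding more of them in memory.
for state_dict in load_state_dict_shards(
    checkpoint_files,
    is_optim=False,
    use_safetensors=True,
    low_cpu_mem_mode=False,
    prefetch=2,
):
    for name, tensor in state_dict.items():
        ...  # e.g. copy the tensor into the matching model parameter
```

Note that the prefetch path yields shards in completion order (concurrent.futures.as_completed), not file order; this is harmless for state-dict loading because each shard carries its own disjoint set of keys, but callers must not rely on any particular ordering.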