
Commit c2e8f61

[checkpointio] fix hybrid plugin model save (#6106)
1 parent 89a9a60 commit c2e8f61

File tree

4 files changed (+41, -38 lines)


colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 4 additions & 5 deletions
@@ -21,7 +21,7 @@
     to_padded_tensor,
     to_unpadded_tensor,
 )
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, get_non_persistent_buffers_set
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -105,8 +105,9 @@ def _model_sharder(
                 yield block, block_size
 
         # Save buffers.
+        non_persist_buffers_set = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
-            if buf is not None and name not in model._non_persistent_buffers_set:
+            if buf is not None and name not in non_persist_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
                 block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
                 if block is not None:
@@ -352,9 +353,7 @@ def _load(name: str):
            _load(name)
 
        # Load buffers.
-        non_persistent_buffers = set()
-        for n, m in model.named_modules():
-            non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
+        non_persistent_buffers = get_non_persistent_buffers_set(model)
        for name, buf in model.named_buffers():
            if buf is not None and name not in non_persistent_buffers:
                _load(name)
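Why the old check failed, in a minimal sketch (the toy classes below are illustrative, not from the repository): PyTorch's private _non_persistent_buffers_set attribute holds only a module's own local buffer names, while named_buffers() yields dotted, fully qualified names, so a non-persistent buffer registered on a submodule never matched the root-level set and was written into the checkpoint.

import torch
import torch.nn as nn


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent: excluded from state_dict, so it must not be checkpointed.
        self.register_buffer("cache", torch.zeros(1), persistent=False)


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()


model = Outer()
print(model._non_persistent_buffers_set)            # set() -- empty at the root
print([name for name, _ in model.named_buffers()])  # ['inner.cache']
# Old check: "'inner.cache' not in set()" is True, so the non-persistent
# buffer was saved anyway. The recursive helper introduced by this commit
# collects dotted names from every submodule, closing this hole.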

colossalai/utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
     ensure_path_exists,
     free_storage,
     get_current_device,
+    get_non_persistent_buffers_set,
     is_ddp_ignored,
     set_seed,
 )
@@ -25,4 +26,5 @@
     "set_seed",
     "get_current_device",
     "is_ddp_ignored",
+    "get_non_persistent_buffers_set",
 ]

colossalai/utils/common.py

Lines changed: 33 additions & 1 deletion
@@ -5,10 +5,11 @@
 import random
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Callable
+from typing import Callable, Optional, Set
 
 import numpy as np
 import torch
+import torch.nn as nn
 
 from colossalai.accelerator import get_accelerator
 
@@ -76,3 +77,34 @@ def set_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
+
+
+def get_non_persistent_buffers_set(
+    module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+):
+    r"""
+    Args:
+        memo: a memo to store the set of modules already added to the result
+        prefix: a prefix that will be added to the name of the module
+        remove_duplicate: whether to remove the duplicated module instances in the result
+            or not
+    """
+
+    if memo is None:
+        memo = set()
+    self_non_persistent_set = set()
+    if module not in memo:
+        if remove_duplicate:
+            memo.add(module)
+        self_non_persistent_set = set(
+            map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
+        )
+        for name, sub_module in module._modules.items():
+            if sub_module is None:
+                continue
+            submodule_prefix = prefix + ("." if prefix else "") + name
+            child_non_persistent_set = get_non_persistent_buffers_set(
+                sub_module, memo, submodule_prefix, remove_duplicate
+            )
+            self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+    return self_non_persistent_set
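A usage sketch for the new helper (the Block module below is a toy example, not from the repository): it walks the module tree and returns fully qualified dotted names, matching the naming convention of model.named_buffers(), so callers can filter buffers by simple set membership.

import torch
import torch.nn as nn

from colossalai.utils import get_non_persistent_buffers_set


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("stat", torch.zeros(4))                       # persistent
        self.register_buffer("scratch", torch.zeros(4), persistent=False)  # non-persistent


model = nn.Sequential(Block(), Block())
non_persistent = get_non_persistent_buffers_set(model)
print(non_persistent)  # {'0.scratch', '1.scratch'} -- same names as named_buffers()

# A checkpoint writer can then filter exactly as the fixed sharder does:
for name, buf in model.named_buffers():
    if name not in non_persistent:
        pass  # only persistent buffers reach the saved state dict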

colossalai/zero/gemini/gemini_ddp.py

Lines changed: 2 additions & 32 deletions
@@ -35,7 +35,7 @@
     to_unpadded_tensor,
 )
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
+from colossalai.utils import _cast_float, free_storage, get_non_persistent_buffers_set, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -187,7 +187,7 @@ def __init__(
             pin_memory=pin_memory,
         )
         super().__init__(module)
-        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
+        self._non_persistent_buffers_set = get_non_persistent_buffers_set(module)
         self._cast_buffers()
 
         # register grad hook
@@ -257,36 +257,6 @@ def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
         for p in params_to_ignore:
             p._ddp_to_ignore = True
 
-    def _get_non_persistent_buffers_set(
-        self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
-    ):
-        r"""
-        Args:
-            memo: a memo to store the set of modules already added to the result
-            prefix: a prefix that will be added to the name of the module
-            remove_duplicate: whether to remove the duplicated module instances in the result
-                or not
-        """
-
-        if memo is None:
-            memo = set()
-        self_non_persistent_set = set()
-        if module not in memo:
-            if remove_duplicate:
-                memo.add(module)
-            self_non_persistent_set = set(
-                map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
-            )
-            for name, sub_module in module._modules.items():
-                if sub_module is None:
-                    continue
-                submodule_prefix = prefix + ("." if prefix else "") + name
-                child_non_persistent_set = self._get_non_persistent_buffers_set(
-                    sub_module, memo, submodule_prefix, remove_duplicate
-                )
-                self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
-        return self_non_persistent_set
-
     def _post_forward(self):
         """This function is only triggered for inference."""
         access_list = list(self.chunk_manager.accessed_chunks)
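The memo/remove_duplicate machinery carried over in the move is easiest to see with a shared submodule; a small sketch (toy modules, illustrative only) of the behavior the relocated helper preserves:

import torch
import torch.nn as nn

from colossalai.utils import get_non_persistent_buffers_set

shared = nn.Module()
shared.register_buffer("tmp", torch.zeros(1), persistent=False)

root = nn.Module()
root.a = shared
root.b = shared  # the same instance registered under a second name

print(get_non_persistent_buffers_set(root))
# {'a.tmp'} -- 'b.tmp' is skipped: with remove_duplicate=True the shared
# instance is memoised on first visit and not traversed again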
