
Commit cf519da

[optim] hotfix adam load (#6146)
* [optim] hotfix adam load
* [checkpointio] fix optimizer async io
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* [checkpointio] update test
* [checkpointio] update test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5caad13 commit cf519da

File tree

colossalai/booster/plugin/low_level_zero_plugin.py
colossalai/nn/optimizer/cpu_adam.py
colossalai/testing/comparison.py
colossalai/utils/safetensors.py
tests/test_checkpoint_io/test_safetensors_async_io.py

5 files changed: +142 -79 lines changed

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 1 addition & 1 deletion
@@ -142,7 +142,7 @@ def save_unsharded_optimizer(
             from colossalai.utils.safetensors import save_nested

             f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
-            save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
+            save_nested(f_writer, state_dict)
             self.async_writers.append(f_writer)
         else:
             save_state_dict(state_dict, checkpoint, use_safetensors=False)
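With this fix, save_nested receives the full optimizer state dict and separates tensor from non-tensor entries itself, instead of the caller passing param_groups as separate metadata. A minimal sketch of the changed call site (optimizer and f_writer are assumed to exist as in the plugin code above):

state_dict = optimizer.state_dict()  # {"state": {...}, "param_groups": [...]}
# Old call shape, removed by this commit:
#   save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
# New call shape: param_groups travel inside the flattened payload.
save_nested(f_writer, state_dict)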

colossalai/nn/optimizer/cpu_adam.py

Lines changed: 8 additions & 0 deletions
@@ -81,6 +81,14 @@ def __init__(
         # if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
         self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

+    def load_state_dict(self, state_dict):
+        super().load_state_dict(state_dict)
+        for group in self.param_groups:
+            for p in group["params"]:
+                state = self.state[p]
+                if "step" in state and isinstance(state["step"], torch.Tensor):
+                    state["step"] = int(state["step"].item())
+
     def torch_adam_update(
         self,
         data,
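Background for the hotfix: newer PyTorch Adam implementations keep state["step"] as a zero-dim tensor, while the fused CPUAdamOptimizer extension expects a plain Python int, so the override casts it right after the parent class restores the state. A standalone sketch of the cast on an illustrative state dict (not the real optimizer state):

import torch

state = {"step": torch.tensor(3.0), "exp_avg": torch.zeros(4)}
# Mirror of the fix: unwrap a tensor-valued step counter into a Python int.
if "step" in state and isinstance(state["step"], torch.Tensor):
    state["step"] = int(state["step"].item())
assert state["step"] == 3 and isinstance(state["step"], int)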

colossalai/testing/comparison.py

Lines changed: 2 additions & 4 deletions
@@ -1,4 +1,4 @@
-from typing import Any, List, OrderedDict, Tuple
+from typing import Any, List, OrderedDict

 import torch
 import torch.distributed as dist
@@ -78,9 +78,7 @@ def check_state_dict_equal(
                 v1 = v1.to(v2.dtype)
             assert_close_loose(v1, v2)
         else:
-            if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
-                v2 = tuple(v2)
-            assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"
+            assert v1 == v2, f"{v1} not equals to {v2}"


 def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
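The Tuple coercion is no longer needed because the pickle-based round trip introduced in colossalai/utils/safetensors.py (next file) restores non-tensor values such as betas with their exact Python types. A tiny illustration with made-up values:

# Hypothetical: a betas tuple saved and loaded back with its type intact.
v1 = (0.9, 0.999)
v2 = (0.9, 0.999)  # previously this came back as a list and had to be coerced
assert v1 == v2, f"{v1} not equals to {v2}"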

colossalai/utils/safetensors.py

Lines changed: 86 additions & 55 deletions
@@ -1,6 +1,5 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
-import warnings
 from dataclasses import asdict, dataclass
 from typing import Dict, List, Optional, Tuple

@@ -12,6 +11,26 @@
 except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
 _TYPES_INV = {v: k for k, v in _TYPES.items()}
+import io
+
+from torch.distributed.distributed_c10d import _pickler, _unpickler
+
+
+def _object_to_tensor(obj, device):
+    f = io.BytesIO()
+    _pickler(f).dump(obj)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
+    # Otherwise, it will cause 100X slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    byte_tensor = torch.ByteTensor(byte_storage).to(device)
+    return byte_tensor
+
+
+def _tensor_to_object(tensor, tensor_size):
+    tensor = tensor.cpu()
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()


 @dataclass
@@ -28,49 +47,68 @@ class PreparedData:
     offset: int


-def flatten_dict(nested_dict, parent_key="", separator="^"):
-    """
-    Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
-
-    nested_dict: The input nested dictionary.
-    parent_key: The parent key currently being processed.
-    separator: The separator used to join keys, default is '_', but can be customized to another symbol. :return: A flattened dictionary."
-    """
-    items = []
-    for k, v in nested_dict.items():
-        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
-        if isinstance(v, dict):
-            items.extend(flatten_dict(v, new_key, separator).items())
-        else:
-            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
-            items.append((new_key, v))
-
-    return dict(items)
-
-
-def unflatten_dict(flattened_dict, separator="^"):
-    """
-    Restore a flattened dictionary back to a multi-level nested dictionary.
-
-    flattened_dict: The flattened dictionary.
-    separator: The separator used during flattening, default is '_', but can be customized to another symbol. :return: The restored nested dictionary.
-    """
-    nested_dict = {}
-    for key, value in flattened_dict.items():
-        keys = key.split(separator)
-        try:
-            keys[0] = int(keys[0])
-        except ValueError:
-            warnings.warn(f"{key[0]} can't convert to integer")
-        d = nested_dict
-        for part in keys[:-1]:
-            if part not in d:
-                d[part] = {}
-            d = d[part]
-        assert isinstance(value, torch.Tensor)
-        d[keys[-1]] = value
-
-    return nested_dict
+def _cast_to_tensor(obj):
+    if isinstance(obj, torch.Tensor):
+        return obj
+    return _object_to_tensor(obj, "cpu")
+
+
+def _cast_to_object(tensor: torch.Tensor):
+    return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
+
+
+def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
+    flat_dict = {}
+    non_tensor_keys = []
+    if "state" in state_dict:
+        # 3-level dict
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for idx, d in states.items():
+        for k, v in d.items():
+            nested_key = f"state{seperator}{idx}{seperator}{k}"
+            if not isinstance(v, torch.Tensor):
+                non_tensor_keys.append(nested_key)
+            flat_dict[nested_key] = _cast_to_tensor(v)
+    if "param_groups" in state_dict:
+        flat_dict["param_groups"] = _cast_to_tensor(state_dict["param_groups"])
+        non_tensor_keys.append("param_groups")
+    if len(non_tensor_keys) > 0:
+        metadata = {"non_tensor_keys": non_tensor_keys}
+    else:
+        metadata = None
+    return flat_dict, metadata
+
+
+def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
+    state_dict = {}
+    if metadata is not None:
+        non_tensor_keys = json.loads(metadata["non_tensor_keys"])
+    else:
+        non_tensor_keys = []
+    flat_dict = {k: _cast_to_object(v) if k in non_tensor_keys else v for k, v in flat_dict.items()}
+    if "param_groups" in flat_dict:
+        # 3-level dict
+        state_dict["param_groups"] = flat_dict.pop("param_groups")
+        state_dict["state"] = {}
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for k, v in flat_dict.items():
+        parts = k.split(seperator)
+        assert len(parts) == 3 and parts[0] == "state"
+        idx = int(parts[1])
+        key = parts[2]
+        if idx not in states:
+            states[idx] = {}
+        states[idx][key] = v
+
+    return state_dict


 def prepare(
@@ -124,10 +162,8 @@ def save(
     f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)


-def save_nested(
-    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
-) -> None:
-    flatten_data = flatten_dict(state_dict)
+def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+    flatten_data, metadata = _flatten_optim_state_dict(state_dict)
     save(f_writer, flatten_data, metadata)


@@ -154,10 +190,5 @@ def load_flat(checkpoint_path):
     with safe_open(checkpoint_path, framework="pt") as f:
         metadata = f.metadata()
     state_dict_load = load_file(checkpoint_path)
-    state_dict = unflatten_dict(state_dict_load)
-    if metadata is None:
-        return state_dict
-    metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
-    combined_state_dict = {"state": state_dict}
-    combined_state_dict.update(metadata)
-    return combined_state_dict
+    state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
+    return state_dict
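Taken together, the new helpers pickle each non-tensor value into a byte tensor, record its flattened key under "non_tensor_keys", and reverse the process on load. A minimal round-trip sketch using the helpers above; the json.dumps step stands in for the metadata encoding that save and load_flat are assumed to perform via the safetensors header:

import json
import torch

# Toy 3-level optimizer state dict: one param whose "step" is a non-tensor.
state_dict = {
    "state": {0: {"step": 3, "exp_avg": torch.zeros(4)}},
    "param_groups": [{"lr": 1e-3, "betas": (0.9, 0.999)}],
}

flat, metadata = _flatten_optim_state_dict(state_dict)
# flat now maps "state.0.step", "state.0.exp_avg" and "param_groups" to tensors;
# metadata == {"non_tensor_keys": ["state.0.step", "param_groups"]}

# Emulate the JSON encoding applied to metadata values in the file header.
metadata = {k: json.dumps(v) for k, v in metadata.items()}
restored = _unflatten_optim_state_dict(flat, metadata)
assert restored["state"][0]["step"] == 3
assert restored["param_groups"][0]["betas"] == (0.9, 0.999)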

tests/test_checkpoint_io/test_safetensors_async_io.py

Lines changed: 45 additions & 19 deletions
@@ -1,27 +1,39 @@
 import tempfile
-from copy import deepcopy

 import torch
+from safetensors.torch import load_file

-from colossalai.utils.safetensors import load_flat, save_nested
+from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested

 try:
     from tensornvme.async_file_io import AsyncFileWriter
 except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

 from colossalai.testing import check_state_dict_equal
+from colossalai.utils import get_current_device


 def test_save_load():
     with tempfile.TemporaryDirectory() as tempdir:
         optimizer_state_dict = {
-            0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-        }
-        # group_dict = {"param_groups": [0, 1, 2]}
-        group_dict = {
+            "state": {
+                0: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                1: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                2: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+            },
             "param_groups": [
                 {
                     "lr": 0.001,
@@ -94,22 +106,26 @@ def test_save_load():
                         61,
                     ],
                 }
-            ]
+            ],
         }
-        metadata = deepcopy(group_dict)
+
         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
         f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
-
-        save_nested(f_writer, optimizer_state_dict, metadata)
+        save_nested(f_writer, optimizer_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()
-
         load_state_dict = load_flat(optimizer_saved_path)
-        state_dict = load_state_dict["state"]
-        group = {"param_groups": load_state_dict["param_groups"]}
-        check_state_dict_equal(optimizer_state_dict, state_dict)
-        check_state_dict_equal(group_dict, group)
+        check_state_dict_equal(load_state_dict, optimizer_state_dict)
+
+        optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
+        f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
+        save_nested(f_writer, optimizer_state_dict["state"])
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict_shard = load_flat(optimizer_shard_saved_path)
+        check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])

         model_state_dict = {
             "module.weight0": torch.rand((1024, 1024)),
@@ -118,10 +134,20 @@ def test_save_load():
         }
         model_saved_path = f"{tempdir}/save_model.safetensors"
         f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
-        save_nested(f_writer, model_state_dict)
+        save(f_writer, model_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()
+        load_state_dict = load_file(model_saved_path)
+        check_state_dict_equal(model_state_dict, load_state_dict)

-        load_state_dict = load_flat(model_saved_path)
+        model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
+        model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
+        model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
+        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict = load_file(model_saved_path)
         check_state_dict_equal(model_state_dict, load_state_dict)
