Commit 0136312

[Unified Checkpoint] update async save (#8801)
* update async save
* move load_unified_optimizer
* update async save
* tmp
* add config
* update final async save
* add load_non_merge_optimizer
* update unlink func
* update final signal
* update windows for async_save
* fix rng load
* fix unittest
* fix distdataloader, update async_save
* update unlink
* remove dist dataloader
* fix load optimizer oom
1 parent 27f8462 commit 0136312

File tree

5 files changed: +858 −404 lines changed

Lines changed: 148 additions & 0 deletions

@@ -0,0 +1,148 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared Memory Utils"""

from dataclasses import dataclass
from typing import List, Mapping, Tuple

import numpy as np
import paddle

from paddlenlp.transformers.utils import device_guard


@dataclass
class TensorMeta:
    shape: Tuple[int] = None  # type: ignore
    dtype: paddle.dtype = None  # type: ignore
    element_size: int = 0
    numel: int = 0
    offset: int = 0


# numpy has no native bfloat16, so paddle.bfloat16 is stored as raw uint16 bits.
dtype_mapping = {
    paddle.float32: np.float32,
    paddle.float64: np.float64,
    paddle.int32: np.int32,
    paddle.int64: np.int64,
    paddle.uint8: np.uint8,
    paddle.bool: np.bool_,
    paddle.float16: np.float16,
    paddle.bfloat16: np.uint16,
    paddle.complex64: np.complex64,
    paddle.complex128: np.complex128,
}


def _write_shared_memory(value: paddle.Tensor, meta: TensorMeta, buffer):
    """
    Write a CPU tensor into the shared memory.
    """
    if value.numel() == 0:
        return
    # View the target slice of the shared buffer as a numpy array, wrap it as a
    # zero-copy paddle tensor, then copy the source tensor into it in place.
    shm_numpy = np.frombuffer(
        buffer, dtype=dtype_mapping[value.dtype], count=int(value.numel()), offset=int(meta.offset)
    )
    with device_guard("cpu"):
        shm_tensor = paddle.Tensor(shm_numpy, zero_copy=True).reshape(value.shape)
        shm_tensor.copy_(value, False)


def _traverse_copy_to_shm(value, meta, buffer):
    """
    Recursively walk ``value``, copying every tensor into ``buffer`` at the
    offset recorded in the matching ``meta`` entry; non-tensor leaves are
    stored directly in ``meta``.
    """
    if isinstance(value, Mapping):
        for k, v in value.items():
            if isinstance(v, (Mapping, List)):
                m = meta[k]
                _traverse_copy_to_shm(v, m, buffer)
            elif paddle.is_tensor(v):
                m = meta[k]
                _write_shared_memory(v, m, buffer)
            else:
                meta[k] = v
    elif isinstance(value, List):
        for i, v in enumerate(value):
            if isinstance(v, (Mapping, List)):
                m = meta[i]
                _traverse_copy_to_shm(v, m, buffer)
            elif paddle.is_tensor(v):
                m = meta[i]
                _write_shared_memory(v, m, buffer)
            else:
                meta[i] = v


def _read_ndarray_from_buf(value, shm_tensor_buffer):
    """
    Read a numpy array from the buffer of shared memory.
    """
    if isinstance(value, TensorMeta):
        if value.numel == 0:
            return np.array([], dtype=dtype_mapping[value.dtype])
        else:
            shm_numpy = np.frombuffer(
                buffer=shm_tensor_buffer.buf,
                dtype=dtype_mapping[value.dtype],
                offset=value.offset,
                count=value.numel,
            ).reshape(value.shape)
            return shm_numpy
    else:
        return value


def _read_state_dict_from_shm(meta_dict, tensor_shm):
    state_dict = _traverse_state_dict(
        meta_dict,
        lambda x: _read_ndarray_from_buf(x, tensor_shm),
    )
    return state_dict


def _traverse_state_dict(value, visitor):
    """
    Invoke ``visitor`` on each leaf of ``value`` recursively, returning a
    structure of the same shape.
    """
    if isinstance(value, Mapping):
        temp_dict = {}
        for k, v in value.items():
            temp_dict[k] = _traverse_state_dict(v, visitor)
        return temp_dict
    elif isinstance(value, List):
        temp_list = []
        for _, v in enumerate(value):
            temp_list.append(_traverse_state_dict(v, visitor))
        return temp_list
    else:
        return visitor(value)


def create_meta_dict(state_dict):
    """
    Build a TensorMeta tree mirroring ``state_dict`` and return it together
    with the total buffer size (in bytes) needed to hold every tensor.
    """
    buffer_size = 0

    def _create_tensor_meta(value: paddle.Tensor):
        nonlocal buffer_size
        if not paddle.is_tensor(value):
            return value
        meta = TensorMeta(
            shape=tuple(value.shape),  # type: ignore
            dtype=value.dtype,
            element_size=value.element_size(),
            numel=int(value.numel()),
            offset=int(buffer_size),
        )
        buffer_size += value.numel() * value.element_size()
        return meta

    meta_dict = _traverse_state_dict(state_dict, _create_tensor_meta)
    return meta_dict, buffer_size
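
A minimal round-trip sketch (not part of the commit) of how these helpers fit together: plan the layout with create_meta_dict, copy the tensors into a multiprocessing.shared_memory block with _traverse_copy_to_shm, and rebuild the state dict from the same buffer with _read_state_dict_from_shm. The toy state_dict and variable names are illustrative, and the tensor is placed on CPU since _write_shared_memory documents a CPU source; in the real async-save path the write and read would happen in different processes.

from multiprocessing import shared_memory

import paddle

state_dict = {"linear.weight": paddle.ones([2, 3], dtype="float32").cpu()}

# Plan the buffer layout: a TensorMeta tree plus the total byte size.
meta_dict, buffer_size = create_meta_dict(state_dict)

# int() guards against buffer_size being a 0-D paddle tensor.
shm = shared_memory.SharedMemory(create=True, size=int(buffer_size))
try:
    # Copy every tensor into the shared buffer at its planned offset.
    _traverse_copy_to_shm(state_dict, meta_dict, shm.buf)
    # Rebuild the state dict as numpy views over the same buffer.
    restored = _read_state_dict_from_shm(meta_dict, shm)
    # The views alias the shared buffer, so copy before releasing it.
    weight = restored["linear.weight"].copy()
    del restored
finally:
    shm.close()
    shm.unlink()

assert weight.shape == (2, 3)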
