Commit 6280cb1

[checkpointio] support debug log (#6153)
* [checkpointio] support debug log
* [checkpointio] refactor async writer api
* fix test
* fix test
1 parent ab856fd commit 6280cb1

File tree

9 files changed, +33 -54 lines changed

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 2 additions & 10 deletions
@@ -137,12 +137,10 @@ def save_unsharded_optimizer(
         state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
         if self.coordinator.is_master():
             if use_async:
-                from tensornvme.async_file_io import AsyncFileWriter

                 from colossalai.utils.safetensors import save_nested

-                f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
-                save_nested(f_writer, state_dict)
+                f_writer = save_nested(checkpoint, state_dict)
                 self.async_writers.append(f_writer)
             else:
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)

@@ -222,16 +220,10 @@ def save_sharded_optimizer(
             checkpoint_file_path = os.path.join(checkpoint, shard_file)
             if self.coordinator.is_master():
                 if use_async:
-                    from tensornvme.async_file_io import AsyncFileWriter

                     from colossalai.utils.safetensors import save_nested

-                    f_writer = AsyncFileWriter(
-                        checkpoint_file_path,
-                        n_entries=self.N_WRITE_ENTRIES,
-                        backend="pthread",
-                    )
-                    save_nested(f_writer, shard)
+                    f_writer = save_nested(checkpoint_file_path, shard)
                     self.async_writers.append(f_writer)
                 else:
                     save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
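
Note: the hunks above show the core of the refactor. save_nested (and, in later files, save and move_and_save) now takes the destination path, constructs the AsyncFileWriter internally, and returns it, so callers only collect the writer and drain it later. A minimal sketch of that calling convention, assuming tensornvme is installed and using a throwaway Adam optimizer; the drain calls mirror the updated test further down:

    import torch
    from colossalai.utils.safetensors import load_flat, save_nested

    # Build a tiny optimizer with real state so there is something to save.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model(torch.rand(2, 4)).sum().backward()
    optimizer.step()

    async_writers = []
    f_writer = save_nested("optimizer.safetensors", optimizer.state_dict())  # returns the AsyncFileWriter
    async_writers.append(f_writer)

    # Drain pending writes before reading the file back or exiting.
    for w in async_writers:
        w.sync_before_step()
        w.synchronize()

    restored = load_flat("optimizer.safetensors")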

colossalai/checkpoint_io/checkpoint_io_base.py

Lines changed: 0 additions & 2 deletions
@@ -59,8 +59,6 @@ class CheckpointIO(ABC):
     >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
     """

-    N_WRITE_ENTRIES: int = 32
-
     # ======================================
     # Public methods
     # ======================================

colossalai/checkpoint_io/general_checkpoint_io.py

Lines changed: 1 addition & 4 deletions
@@ -54,13 +54,11 @@ def save_unsharded_model(
             pass

         if use_async:
-            from tensornvme.async_file_io import AsyncFileWriter

-            writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
             if id(model) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+            writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
             self.async_writers.append(writer)
-            move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])

         else:
             # save the checkpoint

@@ -196,7 +194,6 @@ def save_sharded_model(
                 base_filename=weights_name,
                 is_master=True,
                 pinned_state_dict=pinned_state_dict,
-                n_write_entries=self.N_WRITE_ENTRIES,
             )
             self.pinned_state_dicts[id(model)] = new_pinned_state_dict
             self.async_writers.extend(writers)

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 1 addition & 3 deletions
@@ -686,15 +686,13 @@ def save_unsharded_model(
             for _state_dict in state_dict_list:
                 complete_state_dict.update(_state_dict)
             if use_async:
-                from tensornvme.async_file_io import AsyncFileWriter

                 from colossalai.utils.safetensors import move_and_save

-                writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
                 if id(model) not in self.pinned_state_dicts:
                     self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
                 self.async_writers.append(writer)
-                move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
             else:
                 save_state_dict(complete_state_dict, checkpoint, use_safetensors)

colossalai/checkpoint_io/utils.py

Lines changed: 2 additions & 6 deletions
@@ -273,7 +273,6 @@ def async_save_state_dict_shards(
     base_filename: str,
     is_master: bool,
     pinned_state_dict: Optional[Dict[str, torch.Tensor]],
-    n_write_entries: int,
     use_pp_format: bool = False,
 ) -> Tuple[int, Dict[str, torch.Tensor], list]:
     """

@@ -290,7 +289,6 @@ def async_save_state_dict_shards(
     Returns:
         int: the total size of shards
     """
-    from tensornvme.async_file_io import AsyncFileWriter

     total_size = 0
     shard_filenames = []

@@ -311,17 +309,15 @@
         index_file.append_weight_map(key, shard_file)
         checkpoint_file_path = os.path.join(checkpoint, shard_file)

-        writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
-        writers.append(writer)
-
         if pinned_state_dict is not None:
             sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
         else:
             sub_pinned_state_dict = create_pinned_state_dict(shard)
         returned_state_dict.update(sub_pinned_state_dict)

         # Only save on master rank.
-        move_and_save(writer, shard, sub_pinned_state_dict)
+        writer = move_and_save(checkpoint_file_path, shard, sub_pinned_state_dict)
+        writers.append(writer)
         shard_filenames.append(shard_file)
         del shard

colossalai/utils/safetensors.py

Lines changed: 10 additions & 8 deletions
@@ -15,6 +15,8 @@

 from torch.distributed.distributed_c10d import _pickler, _unpickler

+ASYNC_WRITE_ENTRIES = 32
+

 def _object_to_tensor(obj, device):
     f = io.BytesIO()

@@ -149,32 +151,31 @@ def prepare(
     return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys


-def save(
-    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
-) -> None:
+def save(path: str, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None) -> None:
     prepared_data, tensors, _ = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
-
+    f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensors))
     f_writer.write(n.to_bytes(8, byteorder="little"))
     f_writer.write(header_bytes)

     for tensor in tensors:
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
+    return f_writer


-def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+def save_nested(path: str, state_dict: Dict[str, torch.Tensor]) -> None:
     flatten_data, metadata = _flatten_optim_state_dict(state_dict)
-    save(f_writer, flatten_data, metadata)
+    return save(path, flatten_data, metadata)


 def move_and_save(
-    f_writer: AsyncFileWriter,
+    path: str,
     state_dict: Dict[str, torch.Tensor],
     state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
 ) -> None:
     prepared_data, _, tensor_keys = prepare(state_dict)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
-
+    f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys))
     f_writer.write(n.to_bytes(8, byteorder="little"))
     f_writer.write(header_bytes)

@@ -184,6 +185,7 @@ def move_and_save(
             f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
         else:
             f_writer.write_tensor(state_dict[name])
+    return f_writer


 def load_flat(checkpoint_path):
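
The writer lifecycle now lives entirely inside colossalai.utils.safetensors: save, save_nested, and move_and_save open the AsyncFileWriter themselves (n_tasks = 2 + len(tensors) budgets the two header writes plus one task per tensor, matching the calls in the hunk above) and return the writer for the caller to synchronize. A short sketch of the plain-tensor path, assuming tensornvme is installed; once the writer is drained, the output is an ordinary safetensors file:

    import torch
    from safetensors.torch import load_file
    from colossalai.utils.safetensors import save

    state = {"weight": torch.rand(8, 8), "bias": torch.rand(8)}
    writer = save("model.safetensors", state)   # queues header + tensor writes, returns the writer
    writer.sync_before_step()
    writer.synchronize()
    del writer

    loaded = load_file("model.safetensors")
    assert all(torch.equal(state[k], loaded[k]) for k in state)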

colossalai/zero/low_level/bookkeeping/tensor_bucket.py

Lines changed: 7 additions & 3 deletions
@@ -83,7 +83,11 @@ def all_gather(self, group=None, fp8_communication: bool = False):
         unflat_buffers = list(map(list, zip(*unflat_buffers)))
         for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
             write_back_tensor = self._write_back_pairs[tensor]
-            write_back_tensor.data.copy_(
-                _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
-            )
+            rec_tensor = _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()]
+            if write_back_tensor.is_contiguous():
+                rec_tensor = rec_tensor.view_as(write_back_tensor)
+            else:
+                rec_tensor = rec_tensor.reshape_as(write_back_tensor)
+            write_back_tensor.data.copy_(rec_tensor)
+
         self.empty()
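
This change splits the write-back into two cases: when the destination tensor is contiguous, the gathered data is shaped with view_as, which only re-interprets strides and never copies; otherwise it falls back to reshape_as, which is allowed to materialize a copy. A small illustration of the general view/reshape semantics the branch relies on (plain PyTorch, independent of this file):

    import torch

    flat = torch.arange(12, dtype=torch.float32)
    target = torch.empty(3, 4)
    flat.view_as(target)        # contiguous source: pure re-interpretation, no copy
    flat.reshape_as(target)     # same result here, still no copy

    t = torch.arange(12).reshape(3, 4).t()   # transposed -> non-contiguous
    try:
        t.view(12)                           # view cannot express this layout
    except RuntimeError as err:
        print("view failed:", err)
    print(t.reshape(12))                     # reshape succeeds by copying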

tests/test_checkpoint_io/test_safetensors_async_io.py

Lines changed: 7 additions & 17 deletions
@@ -3,18 +3,12 @@
 import torch
 from safetensors.torch import load_file

-from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
-
-try:
-    from tensornvme.async_file_io import AsyncFileWriter
-except ModuleNotFoundError:
-    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
-
-
-from colossalai.testing import check_state_dict_equal
+from colossalai.testing import check_state_dict_equal, clear_cache_before_run
 from colossalai.utils import get_current_device
+from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested


+@clear_cache_before_run()
 def test_save_load():
     with tempfile.TemporaryDirectory() as tempdir:
         optimizer_state_dict = {

@@ -111,17 +105,15 @@ def test_save_load():
         }

         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
-        f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
-        save_nested(f_writer, optimizer_state_dict)
+        f_writer = save_nested(optimizer_saved_path, optimizer_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         del f_writer
         load_state_dict = load_flat(optimizer_saved_path)
         check_state_dict_equal(load_state_dict, optimizer_state_dict)

         optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
-        f_writer = AsyncFileWriter(optimizer_shard_saved_path, n_entries=191, backend="pthread")
-        save_nested(f_writer, optimizer_state_dict["state"])
+        f_writer = save_nested(optimizer_shard_saved_path, optimizer_state_dict["state"])
         f_writer.sync_before_step()
         f_writer.synchronize()
         del f_writer

@@ -134,8 +126,7 @@ def test_save_load():
             "module.weight2": torch.rand((1024, 1024)),
         }
         model_saved_path = f"{tempdir}/save_model.safetensors"
-        f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
-        save(f_writer, model_state_dict)
+        f_writer = save(model_saved_path, model_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         del f_writer

@@ -145,8 +136,7 @@ def test_save_load():
         model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
         model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
         model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
-        f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
-        move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
+        f_writer = move_and_save(model_saved_path, model_state_dict_cuda, model_state_pinned)
         f_writer.sync_before_step()
         f_writer.synchronize()
         del f_writer

tests/test_optimizer/test_dist_lamb.py

Lines changed: 3 additions & 1 deletion
@@ -10,7 +10,7 @@
 from colossalai.nn.optimizer import DistributedLamb, Lamb
 from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor
 from colossalai.tensor.d_tensor.api import clear_layout_converter
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from colossalai.zero import LowLevelZeroOptimizer
 from tests.kit.model_zoo import model_zoo

@@ -108,6 +108,7 @@ def set_dist_grad(
 @parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
 @parameterize("bias_correction", [False, True])
 @parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)])
+@clear_cache_before_run()
 def run_dist_lamb_basic(
     bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
 ) -> None:

@@ -177,6 +178,7 @@ def run_dist_lamb_basic(
 @parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
 @parameterize("bias_correction", [False, True])
 @parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)])
+@clear_cache_before_run()
 def run_dist_lamb_fwd_bwd(
     bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
 ) -> None:
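
These hunks add the clear_cache_before_run() decorator from colossalai.testing to the Lamb tests, presumably so each parameterized run starts without cached allocations left over from earlier tests. As a rough, hypothetical stand-in for illustration only (not the actual colossalai.testing implementation), a decorator along these lines could look like:

    import functools
    import gc

    import torch

    def clear_cache_before_run_sketch():
        # Hypothetical illustration; the real decorator lives in colossalai.testing.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()  # drop cached CUDA blocks before the test body runs
                return fn(*args, **kwargs)
            return wrapper
        return decorator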
