Skip to content

Commit 8fddbab

Browse files
committed
[checkpointio] disable buffering
1 parent cf519da commit 8fddbab

File tree

4 files changed

+11
-5
lines changed

4 files changed

+11
-5
lines changed

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ def save_unsharded_optimizer(
141141

142142
from colossalai.utils.safetensors import save_nested
143143

144-
f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
144+
f_writer = AsyncFileWriter(
145+
fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
146+
)
145147
save_nested(f_writer, state_dict)
146148
self.async_writers.append(f_writer)
147149
else:
@@ -225,7 +227,9 @@ def save_sharded_optimizer(
225227
from colossalai.utils.safetensors import save_nested
226228

227229
f_writer = AsyncFileWriter(
228-
fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
230+
fp=open(checkpoint_file_path, "wb", buffering=0),
231+
n_entries=self.N_WRITE_ENTRIES,
232+
backend="pthread",
229233
)
230234
save_nested(f_writer, shard)
231235
self.async_writers.append(f_writer)

colossalai/checkpoint_io/general_checkpoint_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def save_unsharded_model(
5656
if use_async:
5757
from tensornvme.async_file_io import AsyncFileWriter
5858

59-
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
59+
writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
6060
if id(model) not in self.pinned_state_dicts:
6161
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
6262
self.async_writers.append(writer)

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,9 @@ def save_unsharded_model(
690690

691691
from colossalai.utils.safetensors import move_and_save
692692

693-
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
693+
writer = AsyncFileWriter(
694+
open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
695+
)
694696
if id(model) not in self.pinned_state_dicts:
695697
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
696698
self.async_writers.append(writer)

colossalai/checkpoint_io/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def async_save_state_dict_shards(
311311
index_file.append_weight_map(key, shard_file)
312312
checkpoint_file_path = os.path.join(checkpoint, shard_file)
313313

314-
writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
314+
writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
315315
writers.append(writer)
316316

317317
if pinned_state_dict is not None:

0 commit comments

Comments
 (0)