Skip to content

Commit e994c64

Browse files
authored
[checkpointio] fix async io (#6155)
1 parent de3d371 commit e994c64

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

colossalai/checkpoint_io/general_checkpoint_io.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import torch.nn as nn
99
from torch.optim import Optimizer
1010

11-
from colossalai.utils.safetensors import move_and_save
12-
1311
from .checkpoint_io_base import CheckpointIO
1412
from .index_file import CheckpointIndexFile
1513
from .utils import (
@@ -54,6 +52,7 @@ def save_unsharded_model(
5452
pass
5553

5654
if use_async:
55+
from colossalai.utils.safetensors import move_and_save
5756

5857
if id(model) not in self.pinned_state_dicts:
5958
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)

colossalai/checkpoint_io/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
to_global,
2020
to_global_for_customized_distributed_tensor,
2121
)
22-
from colossalai.utils.safetensors import move_and_save
2322

2423
SAFE_WEIGHTS_NAME = "model.safetensors"
2524
WEIGHTS_NAME = "pytorch_model.bin"
@@ -289,6 +288,7 @@ def async_save_state_dict_shards(
289288
Returns:
290289
int: the total size of shards
291290
"""
291+
from colossalai.utils.safetensors import move_and_save
292292

293293
total_size = 0
294294
shard_filenames = []

0 commit comments

Comments
 (0)