2 files changed, +2 -3 lines changed
First changed file:

@@ -8,8 +8,6 @@
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.utils.safetensors import move_and_save
-
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -54,6 +52,7 @@ def save_unsharded_model(
             pass
 
         if use_async:
+            from colossalai.utils.safetensors import move_and_save
 
             if id(model) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
Second changed file:

@@ -19,7 +19,6 @@
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.utils.safetensors import move_and_save
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -289,6 +288,7 @@ def async_save_state_dict_shards(
     Returns:
         int: the total size of shards
     """
+    from colossalai.utils.safetensors import move_and_save
 
     total_size = 0
     shard_filenames = []
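
Both files make the same change: the module-level `from colossalai.utils.safetensors import move_and_save` import is removed and re-added inside the functions that use it, so it is only evaluated on the `use_async` path. The sketch below is a minimal illustration of that deferred-import pattern, not ColossalAI code; the names `save_state_dict`, `state_dict`, and `path` are assumptions made for the example.

import torch

def save_state_dict(state_dict, path, use_async=False):
    """Illustrative only: defer importing async machinery until it is needed."""
    if use_async:
        # Deferred import: pulled in only on the async branch, analogous to
        # moving `move_and_save` from module level into the `use_async` block
        # so that importing the checkpoint module stays cheap on the common path.
        from concurrent.futures import ThreadPoolExecutor
        pool = ThreadPoolExecutor(max_workers=1)
        return pool.submit(torch.save, state_dict, path)
    # Synchronous default path needs no extra imports.
    torch.save(state_dict, path)

Typical motivations for deferring an import like this are avoiding import-time overhead or breaking a circular import on paths that never need the symbol; the diff itself does not record which applies here.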