3
3
import torch
4
4
from safetensors .torch import load_file
5
5
6
- from colossalai .utils .safetensors import load_flat , move_and_save , save , save_nested
7
-
8
- try :
9
- from tensornvme .async_file_io import AsyncFileWriter
10
- except ModuleNotFoundError :
11
- raise ModuleNotFoundError ("Please install tensornvme to use NVMeOptimizer" )
12
-
13
-
14
- from colossalai .testing import check_state_dict_equal
6
+ from colossalai .testing import check_state_dict_equal , clear_cache_before_run
15
7
from colossalai .utils import get_current_device
8
+ from colossalai .utils .safetensors import load_flat , move_and_save , save , save_nested
16
9
17
10
11
+ @clear_cache_before_run ()
18
12
def test_save_load ():
19
13
with tempfile .TemporaryDirectory () as tempdir :
20
14
optimizer_state_dict = {
@@ -111,17 +105,15 @@ def test_save_load():
111
105
}
112
106
113
107
optimizer_saved_path = f"{ tempdir } /save_optimizer.safetensors"
114
- f_writer = AsyncFileWriter (optimizer_saved_path , n_entries = 191 , backend = "pthread" )
115
- save_nested (f_writer , optimizer_state_dict )
108
+ f_writer = save_nested (optimizer_saved_path , optimizer_state_dict )
116
109
f_writer .sync_before_step ()
117
110
f_writer .synchronize ()
118
111
del f_writer
119
112
load_state_dict = load_flat (optimizer_saved_path )
120
113
check_state_dict_equal (load_state_dict , optimizer_state_dict )
121
114
122
115
optimizer_shard_saved_path = f"{ tempdir } /save_optimizer_shard.safetensors"
123
- f_writer = AsyncFileWriter (optimizer_shard_saved_path , n_entries = 191 , backend = "pthread" )
124
- save_nested (f_writer , optimizer_state_dict ["state" ])
116
+ f_writer = save_nested (optimizer_shard_saved_path , optimizer_state_dict ["state" ])
125
117
f_writer .sync_before_step ()
126
118
f_writer .synchronize ()
127
119
del f_writer
@@ -134,8 +126,7 @@ def test_save_load():
134
126
"module.weight2" : torch .rand ((1024 , 1024 )),
135
127
}
136
128
model_saved_path = f"{ tempdir } /save_model.safetensors"
137
- f_writer = AsyncFileWriter (model_saved_path , n_entries = 191 , backend = "pthread" )
138
- save (f_writer , model_state_dict )
129
+ f_writer = save (model_saved_path , model_state_dict )
139
130
f_writer .sync_before_step ()
140
131
f_writer .synchronize ()
141
132
del f_writer
@@ -145,8 +136,7 @@ def test_save_load():
145
136
model_state_dict_cuda = {k : v .to (get_current_device ()) for k , v in model_state_dict .items ()}
146
137
model_state_pinned = {k : v .pin_memory () for k , v in model_state_dict .items ()}
147
138
model_saved_path = f"{ tempdir } /save_model_cuda.safetensors"
148
- f_writer = AsyncFileWriter (model_saved_path , n_entries = 191 , backend = "pthread" )
149
- move_and_save (f_writer , model_state_dict_cuda , model_state_pinned )
139
+ f_writer = move_and_save (model_saved_path , model_state_dict_cuda , model_state_pinned )
150
140
f_writer .sync_before_step ()
151
141
f_writer .synchronize ()
152
142
del f_writer
0 commit comments