@@ -161,6 +161,7 @@ def weights(self):
 
     def save(self, filename):
        state_dict = {}
+       optimizer_saved_dict = {}
 
        for k, v in self.layers.items():
            state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -175,9 +176,10 @@ def save(self, filename):
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
        if self.optimizer_name is not None:
-           state_dict['optimizer_name'] = self.optimizer_name
+           optimizer_saved_dict['optimizer_name'] = self.optimizer_name
        if self.optimizer_state_dict:
-           state_dict['optimizer_state_dict'] = self.optimizer_state_dict
+           optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
+           torch.save(optimizer_saved_dict, filename + '.optim')
 
        torch.save(state_dict, filename)
 
@@ -198,9 +200,11 @@ def load(self, filename):
        print(f"Layer norm is set to {self.add_layer_norm}")
        self.use_dropout = state_dict.get('use_dropout', False)
        print(f"Dropout usage is set to {self.use_dropout}")
-       self.optimizer_name = state_dict.get('optimizer_name', 'AdamW')
+
+       optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
+       self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
        print(f"Optimizer name is {self.optimizer_name}")
-       self.optimizer_state_dict = state_dict.get('optimizer_state_dict', None)
+       self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
        if self.optimizer_state_dict:
            print("Loaded existing optimizer from checkpoint")
        else:
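The change above moves the optimizer's name and state out of the main checkpoint and into a sidecar '<filename>.optim' file, with the load path falling back to an empty dict (and a default of 'AdamW') when no sidecar exists, so older checkpoints still load. Below is a minimal, self-contained sketch of that same save/load pattern under assumed names: the toy nn.Linear model, the `net`/`opt` variables, and the 'demo.pt' filename are illustrative only and not part of this patch.

import os
import torch
import torch.nn as nn

net = nn.Linear(4, 2)                      # stand-in for the real model layers
opt = torch.optim.AdamW(net.parameters())  # optimizer whose state we checkpoint

filename = 'demo.pt'  # hypothetical checkpoint path

# Save: model weights go in the main checkpoint; optimizer name and state
# go in a separate '<filename>.optim' sidecar file, as in the diff above.
torch.save({'weights': net.state_dict()}, filename)
torch.save({'optimizer_name': 'AdamW',
            'optimizer_state_dict': opt.state_dict()}, filename + '.optim')

# Load: the sidecar is optional, so a checkpoint saved without it still
# loads cleanly and training falls back to a fresh default optimizer.
state_dict = torch.load(filename, map_location='cpu')
optim_path = filename + '.optim'
optimizer_saved_dict = torch.load(optim_path, map_location='cpu') if os.path.exists(optim_path) else {}

net.load_state_dict(state_dict['weights'])
optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
if optimizer_saved_dict.get('optimizer_state_dict') is not None:
    opt.load_state_dict(optimizer_saved_dict['optimizer_state_dict'])
    print(f"Loaded existing {optimizer_name} optimizer from checkpoint")

One practical upside of this split is that the main checkpoint stays small and shareable, since optimizer state (e.g. AdamW's per-parameter moment buffers) can be several times the size of the weights themselves.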