@@ -177,12 +177,13 @@ def save(self, filename):
177
177
state_dict ['sd_checkpoint_name' ] = self .sd_checkpoint_name
178
178
if self .optimizer_name is not None :
179
179
optimizer_saved_dict ['optimizer_name' ] = self .optimizer_name
180
+
181
+ torch .save (state_dict , filename )
180
182
if self .optimizer_state_dict :
183
+ optimizer_saved_dict ['hash' ] = sd_models .model_hash (filename )
181
184
optimizer_saved_dict ['optimizer_state_dict' ] = self .optimizer_state_dict
182
185
torch .save (optimizer_saved_dict , filename + '.optim' )
183
186
184
- torch .save (state_dict , filename )
185
-
186
187
def load (self , filename ):
187
188
self .filename = filename
188
189
if self .name is None :
@@ -204,7 +205,10 @@ def load(self, filename):
204
205
optimizer_saved_dict = torch .load (self .filename + '.optim' , map_location = 'cpu' ) if os .path .exists (self .filename + '.optim' ) else {}
205
206
self .optimizer_name = optimizer_saved_dict .get ('optimizer_name' , 'AdamW' )
206
207
print (f"Optimizer name is { self .optimizer_name } " )
207
- self .optimizer_state_dict = optimizer_saved_dict .get ('optimizer_state_dict' , None )
208
+ if sd_models .model_hash (filename ) == optimizer_saved_dict .get ('hash' , None ):
209
+ self .optimizer_state_dict = optimizer_saved_dict .get ('optimizer_state_dict' , None )
210
+ else :
211
+ self .optimizer_state_dict = None
208
212
if self .optimizer_state_dict :
209
213
print ("Loaded existing optimizer from checkpoint" )
210
214
else :
@@ -229,7 +233,7 @@ def list_hypernetworks(path):
229
233
name = os .path .splitext (os .path .basename (filename ))[0 ]
230
234
# Prevent a hypothetical "None.pt" from being listed.
231
235
if name != "None" :
232
- res [name ] = filename
236
+ res [name + f"( { sd_models . model_hash ( filename ) } )" ] = filename
233
237
return res
234
238
235
239
@@ -375,6 +379,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
375
379
else :
376
380
hypernetwork_dir = None
377
381
382
+ hypernetwork_name = hypernetwork_name .rsplit ('(' , 1 )[0 ]
378
383
if create_image_every > 0 :
379
384
images_dir = os .path .join (log_directory , "images" )
380
385
os .makedirs (images_dir , exist_ok = True )
0 commit comments