@@ -119,7 +119,7 @@ def process_file(path, filename):
119
119
vec = emb .detach ().to (devices .device , dtype = torch .float32 )
120
120
embedding = Embedding (vec , name )
121
121
embedding .step = data .get ('step' , None )
122
- embedding .sd_checkpoint = data .get ('hash ' , None )
122
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
123
123
embedding .sd_checkpoint_name = data .get ('sd_checkpoint_name' , None )
124
124
self .register_embedding (embedding , shared .sd_model )
125
125
@@ -259,6 +259,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
259
259
hijack = sd_hijack .model_hijack
260
260
261
261
embedding = hijack .embedding_db .word_embeddings [embedding_name ]
262
+ checkpoint = sd_models .select_checkpoint ()
262
263
263
264
ititial_step = embedding .step or 0
264
265
if ititial_step > steps :
@@ -314,9 +315,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
314
315
315
316
if embedding_dir is not None and steps_done % save_embedding_every == 0 :
316
317
# Before saving, change name to match current checkpoint.
317
- embedding . name = f'{ embedding_name } -{ steps_done } '
318
- last_saved_file = os .path .join (embedding_dir , f'{ embedding . name } .pt' )
319
- embedding . save ( last_saved_file )
318
+ embedding_name_every = f'{ embedding_name } -{ steps_done } '
319
+ last_saved_file = os .path .join (embedding_dir , f'{ embedding_name_every } .pt' )
320
+ save_embedding ( embedding , checkpoint , embedding_name_every , last_saved_file , remove_cached_checksum = True )
320
321
embedding_yet_to_be_embedded = True
321
322
322
323
write_loss (log_directory , "textual_inversion_loss.csv" , embedding .step , len (ds ), {
@@ -397,14 +398,26 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
397
398
</p>
398
399
"""
399
400
400
- checkpoint = sd_models .select_checkpoint ()
401
-
402
- embedding .sd_checkpoint = checkpoint .hash
403
- embedding .sd_checkpoint_name = checkpoint .model_name
404
- embedding .cached_checksum = None
405
- # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
406
- embedding .name = embedding_name
407
- filename = os .path .join (shared .cmd_opts .embeddings_dir , f'{ embedding .name } .pt' )
408
- embedding .save (filename )
401
+ filename = os .path .join (shared .cmd_opts .embeddings_dir , f'{ embedding_name } .pt' )
402
+ save_embedding (embedding , checkpoint , embedding_name , filename , remove_cached_checksum = True )
409
403
410
404
return embedding , filename
405
+
406
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    """Save *embedding* to *filename*, stamping it with checkpoint metadata.

    Before writing, the embedding's ``name``, ``sd_checkpoint`` (the
    checkpoint's hash) and ``sd_checkpoint_name`` are updated so the saved
    file records which model it was trained against.  If saving raises for
    any reason, every mutated attribute is rolled back to its prior value
    and the exception is re-raised, leaving the in-memory embedding
    unchanged on failure.

    Args:
        embedding: object exposing ``name`` and a ``save(filename)`` method;
            the checkpoint-metadata attributes may be absent on it.
        checkpoint: object exposing ``hash`` and ``model_name``.
        embedding_name: name to store on the embedding before saving.
        filename: destination path for the saved ``.pt`` file.
        remove_cached_checksum: when True, clear ``cached_checksum`` so a
            stale checksum is not persisted alongside the new data.
    """
    # Snapshot prior state for rollback; metadata attributes may not exist
    # yet on embeddings loaded from older files, hence the None defaults.
    old_embedding_name = embedding.name
    old_sd_checkpoint = getattr(embedding, "sd_checkpoint", None)
    old_sd_checkpoint_name = getattr(embedding, "sd_checkpoint_name", None)
    old_cached_checksum = getattr(embedding, "cached_checksum", None)
    try:
        embedding.sd_checkpoint = checkpoint.hash
        embedding.sd_checkpoint_name = checkpoint.model_name
        if remove_cached_checksum:
            embedding.cached_checksum = None
        embedding.name = embedding_name
        embedding.save(filename)
    except BaseException:
        # BaseException (not a bare except) keeps the original semantics:
        # roll back even on KeyboardInterrupt, then propagate the error.
        embedding.sd_checkpoint = old_sd_checkpoint
        embedding.sd_checkpoint_name = old_sd_checkpoint_name
        embedding.name = old_embedding_name
        embedding.cached_checksum = old_cached_checksum
        raise
0 commit comments