@@ -119,7 +119,7 @@ def process_file(path, filename):
             vec = emb.detach().to(devices.device, dtype=torch.float32)
             embedding = Embedding(vec, name)
             embedding.step = data.get('step', None)
-            embedding.sd_checkpoint = data.get('hash', None)
+            embedding.sd_checkpoint = data.get('sd_checkpoint', None)
             embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
             self.register_embedding(embedding, shared.sd_model)
 
@@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
         **values,
     })
 
+def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+    assert model_name, f"{name} not selected"
+    assert learn_rate, "Learning rate is empty or 0"
+    assert isinstance(batch_size, int), "Batch size must be integer"
+    assert batch_size > 0, "Batch size must be positive"
+    assert data_root, "Dataset directory is empty"
+    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+    assert os.listdir(data_root), "Dataset directory is empty"
+    assert template_file, "Prompt template file is empty"
+    assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+    assert steps, "Max steps is empty or 0"
+    assert isinstance(steps, int), "Max steps must be integer"
+    assert steps > 0, "Max steps must be positive"
+    assert isinstance(save_model_every, int), f"Save {name} must be integer"
+    assert save_model_every >= 0, f"Save {name} must be positive or 0"
+    assert isinstance(create_image_every, int), "Create image must be integer"
+    assert create_image_every >= 0, "Create image must be positive or 0"
+    if save_model_every or create_image_every:
+        assert log_directory, "Log directory is empty"
 
 def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
-    assert embedding_name, 'embedding not selected'
+    save_embedding_every = save_embedding_every or 0
+    create_image_every = create_image_every or 0
+    validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
 
     shared.state.textinfo = "Initializing textual inversion training..."
     shared.state.job_count = steps
@@ -232,17 +253,28 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         os.makedirs(images_embeds_dir, exist_ok=True)
     else:
         images_embeds_dir = None
-
+
     cond_model = shared.sd_model.cond_stage_model
 
+    hijack = sd_hijack.model_hijack
+
+    embedding = hijack.embedding_db.word_embeddings[embedding_name]
+    checkpoint = sd_models.select_checkpoint()
+
+    ititial_step = embedding.step or 0
+    if ititial_step >= steps:
+        shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+        return embedding, filename
+
+    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+    # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
         ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
 
-    hijack = sd_hijack.model_hijack
-
-    embedding = hijack.embedding_db.word_embeddings[embedding_name]
     embedding.vec.requires_grad = True
+    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
 
     losses = torch.zeros((32,))
 
@@ -251,13 +283,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     forced_filename = "<none>"
     embedding_yet_to_be_embedded = False
 
-    ititial_step = embedding.step or 0
-    if ititial_step > steps:
-        return embedding, filename
-
-    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
-
     pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
     for i, entries in pbar:
         embedding.step = i + ititial_step
@@ -290,9 +315,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
 
         if embedding_dir is not None and steps_done % save_embedding_every == 0:
             # Before saving, change name to match current checkpoint.
-            embedding.name = f'{embedding_name}-{steps_done}'
-            last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
-            embedding.save(last_saved_file)
+            embedding_name_every = f'{embedding_name}-{steps_done}'
+            last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
+            save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
             embedding_yet_to_be_embedded = True
 
         write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
@@ -373,14 +398,26 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
 </p>
 """
 
-    checkpoint = sd_models.select_checkpoint()
-
-    embedding.sd_checkpoint = checkpoint.hash
-    embedding.sd_checkpoint_name = checkpoint.model_name
-    embedding.cached_checksum = None
-    # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
-    embedding.name = embedding_name
-    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt')
-    embedding.save(filename)
+    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+    save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
 
     return embedding, filename
+
+def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+    old_embedding_name = embedding.name
+    old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
+    old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
+    old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
+    try:
+        embedding.sd_checkpoint = checkpoint.hash
+        embedding.sd_checkpoint_name = checkpoint.model_name
+        if remove_cached_checksum:
+            embedding.cached_checksum = None
+        embedding.name = embedding_name
+        embedding.save(filename)
+    except:
+        embedding.sd_checkpoint = old_sd_checkpoint
+        embedding.sd_checkpoint_name = old_sd_checkpoint_name
+        embedding.name = old_embedding_name
+        embedding.cached_checksum = old_cached_checksum
+        raise
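
The save_embedding helper added above stashes the embedding's metadata and restores it if the write fails, so a crashed save cannot leave the in-memory embedding renamed or pointing at the wrong checkpoint. Below is a minimal standalone sketch of that rollback pattern; FakeCheckpoint and FakeEmbedding are hypothetical stand-ins (not webui classes), and the sketch uses getattr and except Exception where the diff uses hasattr ternaries and a bare except.

# Standalone sketch of the save-with-rollback pattern from the diff above.
# FakeCheckpoint and FakeEmbedding are hypothetical stand-ins, not webui classes.

class FakeCheckpoint:
    hash = "abcd1234"
    model_name = "example-model"

class FakeEmbedding:
    def __init__(self, name):
        self.name = name
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.cached_checksum = "stale-checksum"

    def save(self, filename):
        # Simulate a failed torch.save() so the rollback path is exercised.
        raise IOError(f"could not write {filename}")

def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    # Mirrors the helper in the diff: stash metadata, mutate, save, restore on failure.
    old_embedding_name = embedding.name
    old_sd_checkpoint = getattr(embedding, "sd_checkpoint", None)
    old_sd_checkpoint_name = getattr(embedding, "sd_checkpoint_name", None)
    old_cached_checksum = getattr(embedding, "cached_checksum", None)
    try:
        embedding.sd_checkpoint = checkpoint.hash
        embedding.sd_checkpoint_name = checkpoint.model_name
        if remove_cached_checksum:
            embedding.cached_checksum = None
        embedding.name = embedding_name
        embedding.save(filename)
    except Exception:
        embedding.sd_checkpoint = old_sd_checkpoint
        embedding.sd_checkpoint_name = old_sd_checkpoint_name
        embedding.name = old_embedding_name
        embedding.cached_checksum = old_cached_checksum
        raise

emb = FakeEmbedding("my-embedding")
try:
    save_embedding(emb, FakeCheckpoint(), "my-embedding-500", "/tmp/my-embedding-500.pt")
except IOError:
    pass

# The failed save left the embedding's metadata untouched.
assert emb.name == "my-embedding"
assert emb.cached_checksum == "stale-checksum"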