@@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
            **values,
        })

+def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+    assert model_name, f"{name} not selected"
+    assert learn_rate, "Learning rate is empty or 0"
+    assert isinstance(batch_size, int), "Batch size must be integer"
+    assert batch_size > 0, "Batch size must be positive"
+    assert data_root, "Dataset directory is empty"
+    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+    assert os.listdir(data_root), "Dataset directory is empty"
+    assert template_file, "Prompt template file is empty"
+    assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+    assert steps, "Max steps is empty or 0"
+    assert isinstance(steps, int), "Max steps must be integer"
+    assert steps > 0, "Max steps must be positive"
+    assert isinstance(save_model_every, int), f"Save {name} must be integer"
+    assert save_model_every >= 0, f"Save {name} must be positive or 0"
+    assert isinstance(create_image_every, int), "Create image must be integer"
+    assert create_image_every >= 0, "Create image must be positive or 0"
+    if save_model_every or create_image_every:
+        assert log_directory, "Log directory is empty"


def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
-    assert embedding_name, 'embedding not selected'
+    save_embedding_every = save_embedding_every or 0
+    create_image_every = create_image_every or 0
+    validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")

    shared.state.textinfo = "Initializing textual inversion training..."
    shared.state.job_count = steps
@@ -232,17 +253,27 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
        os.makedirs(images_embeds_dir, exist_ok=True)
    else:
        images_embeds_dir = None
-
+
    cond_model = shared.sd_model.cond_stage_model

+    hijack = sd_hijack.model_hijack
+
+    embedding = hijack.embedding_db.word_embeddings[embedding_name]
+
+    ititial_step = embedding.step or 0
+    if ititial_step > steps:
+        shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+        return embedding, filename
+
+    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)

-    hijack = sd_hijack.model_hijack
-
-    embedding = hijack.embedding_db.word_embeddings[embedding_name]
    embedding.vec.requires_grad = True
+    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)

    losses = torch.zeros((32,))

@@ -251,13 +282,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
    forced_filename = "<none>"
    embedding_yet_to_be_embedded = False

-    ititial_step = embedding.step or 0
-    if ititial_step > steps:
-        return embedding, filename
-
-    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
-
    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
    for i, entries in pbar:
        embedding.step = i + ititial_step
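For context, here is a minimal sketch of how the new validate_train_inputs helper could be reused by another trainer in the same module; the train_hypernetwork signature and its arguments are assumptions for illustration, not part of this diff. The pattern mirrors the train_embedding change above: coerce the two optional counters to 0, then let the shared validator raise a descriptive AssertionError before any expensive dataset loading begins.

# Illustrative sketch only (assumed signature, not part of this commit); relies on
# validate_train_inputs being defined in the same module, as added above.
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root,
                       log_directory, template_file, steps,
                       save_hypernetwork_every, create_image_every, *args):
    # Empty UI fields arrive as None or "" -- normalize to 0 so the ">= 0" checks hold.
    save_hypernetwork_every = save_hypernetwork_every or 0
    create_image_every = create_image_every or 0
    # Fails fast with a readable AssertionError if any input is missing or malformed.
    validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root,
                          template_file, steps, save_hypernetwork_every,
                          create_image_every, log_directory, name="hypernetwork")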