Commit ab27c11

Add input validations before loading dataset for training
1 parent 35c45df commit ab27c11

2 files changed: +58 -28 lines
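
Both training entry points follow the same pattern: optional UI fields arrive as None when left blank, so they are first coalesced with "x or 0" and then passed through a shared validator, all before the expensive dataset load. A minimal sketch of the coalescing idiom, with made-up values:

# Hypothetical illustration of the "x or 0" coalescing used before validation:
# a blank UI field comes through as None, which "or 0" turns into an int
# that the isinstance/>= 0 asserts will accept.
for raw in (None, 0, 500):
    save_every = raw or 0
    assert isinstance(save_every, int) and save_every >= 0
    print(save_every)  # -> 0, 0, 500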

modules/hypernetworks/hypernetwork.py

Lines changed: 22 additions & 16 deletions
@@ -332,7 +332,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
 
-    assert hypernetwork_name, 'hypernetwork not selected'
+    save_hypernetwork_every = save_hypernetwork_every or 0
+    create_image_every = create_image_every or 0
+    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
     shared.loaded_hypernetwork = Hypernetwork()
@@ -358,39 +360,43 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     else:
         images_dir = None
 
+    hypernetwork = shared.loaded_hypernetwork
+
+    ititial_step = hypernetwork.step or 0
+    if ititial_step > steps:
+        shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+        return hypernetwork, filename
+
+    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+    # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
         ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+
     if unload:
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)
 
-    hypernetwork = shared.loaded_hypernetwork
-    weights = hypernetwork.weights()
-    for weight in weights:
-        weight.requires_grad = True
-
     size = len(ds.indexes)
     loss_dict = defaultdict(lambda : deque(maxlen = 1024))
     losses = torch.zeros((size,))
     previous_mean_losses = [0]
     previous_mean_loss = 0
     print("Mean loss of {} elements".format(size))
-
-    last_saved_file = "<none>"
-    last_saved_image = "<none>"
-    forced_filename = "<none>"
-
-    ititial_step = hypernetwork.step or 0
-    if ititial_step > steps:
-        return hypernetwork, filename
-
-    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+    weights = hypernetwork.weights()
+    for weight in weights:
+        weight.requires_grad = True
     # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
     optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
 
     steps_without_grad = 0
 
+    last_saved_file = "<none>"
+    last_saved_image = "<none>"
+    forced_filename = "<none>"
+
     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
     for i, entries in pbar:
         hypernetwork.step = i + ititial_step
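
The reordering above is the core of this file's change: the early return for an already-finished model and the LearnRateScheduler construction now run before PersonalizedBase is built, so a run that would go nowhere never pays for dataset preparation. Below is a minimal, self-contained sketch of that ordering; the names validate_inputs and train are hypothetical stand-ins, not the repo's functions:

import os

def validate_inputs(data_root, steps, batch_size):
    # Cheap precondition checks in the spirit of validate_train_inputs:
    # fail fast with a readable message before any expensive work.
    assert data_root, "Dataset directory is empty"
    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
    assert isinstance(steps, int) and steps > 0, "Max steps must be a positive integer"
    assert isinstance(batch_size, int) and batch_size > 0, "Batch size must be a positive integer"

def train(state, data_root, steps, batch_size=1):
    validate_inputs(data_root, steps, batch_size)

    # Early return before dataset loading: a model resumed past its
    # target step count has nothing left to train.
    initial_step = state.get("step", 0)
    if initial_step > steps:
        return state

    # Only now pay for the slow part (a stand-in for PersonalizedBase).
    dataset = sorted(os.listdir(data_root))
    for offset, _item in enumerate(dataset[:steps - initial_step]):
        state["step"] = initial_step + offset + 1
    return state

The same validate/early-return/load ordering appears again in train_embedding below.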

modules/textual_inversion/textual_inversion.py

Lines changed: 36 additions & 12 deletions
@@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
         **values,
     })
 
+def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+    assert model_name, f"{name} not selected"
+    assert learn_rate, "Learning rate is empty or 0"
+    assert isinstance(batch_size, int), "Batch size must be integer"
+    assert batch_size > 0, "Batch size must be positive"
+    assert data_root, "Dataset directory is empty"
+    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+    assert os.listdir(data_root), "Dataset directory is empty"
+    assert template_file, "Prompt template file is empty"
+    assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+    assert steps, "Max steps is empty or 0"
+    assert isinstance(steps, int), "Max steps must be integer"
+    assert steps > 0, "Max steps must be positive"
+    assert isinstance(save_model_every, int), f"Save {name} must be integer"
+    assert save_model_every >= 0, f"Save {name} must be positive or 0"
+    assert isinstance(create_image_every, int), "Create image must be integer"
+    assert create_image_every >= 0, "Create image must be positive or 0"
+    if save_model_every or create_image_every:
+        assert log_directory, "Log directory is empty"
 
 def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
-    assert embedding_name, 'embedding not selected'
+    save_embedding_every = save_embedding_every or 0
+    create_image_every = create_image_every or 0
+    validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
 
     shared.state.textinfo = "Initializing textual inversion training..."
     shared.state.job_count = steps
@@ -232,17 +253,27 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         os.makedirs(images_embeds_dir, exist_ok=True)
     else:
         images_embeds_dir = None
-
+
     cond_model = shared.sd_model.cond_stage_model
 
+    hijack = sd_hijack.model_hijack
+
+    embedding = hijack.embedding_db.word_embeddings[embedding_name]
+
+    ititial_step = embedding.step or 0
+    if ititial_step > steps:
+        shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+        return embedding, filename
+
+    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+    # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
         ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
 
-    hijack = sd_hijack.model_hijack
-
-    embedding = hijack.embedding_db.word_embeddings[embedding_name]
     embedding.vec.requires_grad = True
+    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
 
     losses = torch.zeros((32,))
 
@@ -251,13 +282,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     forced_filename = "<none>"
     embedding_yet_to_be_embedded = False
 
-    ititial_step = embedding.step or 0
-    if ititial_step > steps:
-        return embedding, filename
-
-    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
-
     pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
     for i, entries in pbar:
         embedding.step = i + ititial_step
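
Since the validator is assert-based, a caller gets exactly one AssertionError per failed precondition, which the webui can surface as the job's error text. A hedged usage sketch follows; the import path matches the file above, but every argument value is invented:

# Hypothetical driver exercising the new validator; only batch_size is invalid
# here, and the checks run in order, so the made-up paths are never touched.
from modules.textual_inversion.textual_inversion import validate_train_inputs

try:
    validate_train_inputs(
        model_name="my-embedding",
        learn_rate=0.005,
        batch_size=0,  # an int, but not positive
        data_root="/path/to/train-images",
        template_file="textual_inversion_templates/style.txt",
        steps=1000,
        save_model_every=500,
        create_image_every=500,
        log_directory="textual_inversion",
    )
except AssertionError as err:
    print(f"Training refused: {err}")  # -> Training refused: Batch size must be positive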
