Skip to content

Commit 3ce2bfd

Browse files
committed
Add cleanup after training
1 parent ab27c11 commit 3ce2bfd

File tree

2 files changed

+200
-186
lines changed

2 files changed

+200
-186
lines changed

modules/hypernetworks/hypernetwork.py

Lines changed: 105 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -398,110 +398,112 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
398398
forced_filename = "<none>"
399399

400400
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
401-
for i, entries in pbar:
402-
hypernetwork.step = i + ititial_step
403-
if len(loss_dict) > 0:
404-
previous_mean_losses = [i[-1] for i in loss_dict.values()]
405-
previous_mean_loss = mean(previous_mean_losses)
406-
407-
scheduler.apply(optimizer, hypernetwork.step)
408-
if scheduler.finished:
409-
break
410-
411-
if shared.state.interrupted:
412-
break
413-
414-
with torch.autocast("cuda"):
415-
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
416-
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
417-
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
418-
loss = shared.sd_model(x, c)[0]
419-
del x
420-
del c
421-
422-
losses[hypernetwork.step % losses.shape[0]] = loss.item()
423-
for entry in entries:
424-
loss_dict[entry.filename].append(loss.item())
425-
426-
optimizer.zero_grad()
427-
weights[0].grad = None
428-
loss.backward()
429401

430-
if weights[0].grad is None:
431-
steps_without_grad += 1
402+
try:
403+
for i, entries in pbar:
404+
hypernetwork.step = i + ititial_step
405+
if len(loss_dict) > 0:
406+
previous_mean_losses = [i[-1] for i in loss_dict.values()]
407+
previous_mean_loss = mean(previous_mean_losses)
408+
409+
scheduler.apply(optimizer, hypernetwork.step)
410+
if scheduler.finished:
411+
break
412+
413+
if shared.state.interrupted:
414+
break
415+
416+
with torch.autocast("cuda"):
417+
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
418+
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
419+
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
420+
loss = shared.sd_model(x, c)[0]
421+
del x
422+
del c
423+
424+
losses[hypernetwork.step % losses.shape[0]] = loss.item()
425+
for entry in entries:
426+
loss_dict[entry.filename].append(loss.item())
427+
428+
optimizer.zero_grad()
429+
weights[0].grad = None
430+
loss.backward()
431+
432+
if weights[0].grad is None:
433+
steps_without_grad += 1
434+
else:
435+
steps_without_grad = 0
436+
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
437+
438+
optimizer.step()
439+
440+
steps_done = hypernetwork.step + 1
441+
442+
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
443+
raise RuntimeError("Loss diverged.")
444+
445+
if len(previous_mean_losses) > 1:
446+
std = stdev(previous_mean_losses)
432447
else:
433-
steps_without_grad = 0
434-
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
435-
436-
optimizer.step()
437-
438-
steps_done = hypernetwork.step + 1
439-
440-
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
441-
raise RuntimeError("Loss diverged.")
442-
443-
if len(previous_mean_losses) > 1:
444-
std = stdev(previous_mean_losses)
445-
else:
446-
std = 0
447-
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
448-
pbar.set_description(dataset_loss_info)
449-
450-
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
451-
# Before saving, change name to match current checkpoint.
452-
hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
453-
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
454-
hypernetwork.save(last_saved_file)
455-
456-
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
457-
"loss": f"{previous_mean_loss:.7f}",
458-
"learn_rate": scheduler.learn_rate
459-
})
460-
461-
if images_dir is not None and steps_done % create_image_every == 0:
462-
forced_filename = f'{hypernetwork_name}-{steps_done}'
463-
last_saved_image = os.path.join(images_dir, forced_filename)
464-
465-
optimizer.zero_grad()
466-
shared.sd_model.cond_stage_model.to(devices.device)
467-
shared.sd_model.first_stage_model.to(devices.device)
468-
469-
p = processing.StableDiffusionProcessingTxt2Img(
470-
sd_model=shared.sd_model,
471-
do_not_save_grid=True,
472-
do_not_save_samples=True,
473-
)
448+
std = 0
449+
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
450+
pbar.set_description(dataset_loss_info)
451+
452+
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
453+
# Before saving, change name to match current checkpoint.
454+
hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
455+
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
456+
hypernetwork.save(last_saved_file)
457+
458+
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
459+
"loss": f"{previous_mean_loss:.7f}",
460+
"learn_rate": scheduler.learn_rate
461+
})
462+
463+
if images_dir is not None and steps_done % create_image_every == 0:
464+
forced_filename = f'{hypernetwork_name}-{steps_done}'
465+
last_saved_image = os.path.join(images_dir, forced_filename)
466+
467+
optimizer.zero_grad()
468+
shared.sd_model.cond_stage_model.to(devices.device)
469+
shared.sd_model.first_stage_model.to(devices.device)
470+
471+
p = processing.StableDiffusionProcessingTxt2Img(
472+
sd_model=shared.sd_model,
473+
do_not_save_grid=True,
474+
do_not_save_samples=True,
475+
)
474476

475-
if preview_from_txt2img:
476-
p.prompt = preview_prompt
477-
p.negative_prompt = preview_negative_prompt
478-
p.steps = preview_steps
479-
p.sampler_index = preview_sampler_index
480-
p.cfg_scale = preview_cfg_scale
481-
p.seed = preview_seed
482-
p.width = preview_width
483-
p.height = preview_height
484-
else:
485-
p.prompt = entries[0].cond_text
486-
p.steps = 20
477+
if preview_from_txt2img:
478+
p.prompt = preview_prompt
479+
p.negative_prompt = preview_negative_prompt
480+
p.steps = preview_steps
481+
p.sampler_index = preview_sampler_index
482+
p.cfg_scale = preview_cfg_scale
483+
p.seed = preview_seed
484+
p.width = preview_width
485+
p.height = preview_height
486+
else:
487+
p.prompt = entries[0].cond_text
488+
p.steps = 20
487489

488-
preview_text = p.prompt
490+
preview_text = p.prompt
489491

490-
processed = processing.process_images(p)
491-
image = processed.images[0] if len(processed.images)>0 else None
492+
processed = processing.process_images(p)
493+
image = processed.images[0] if len(processed.images)>0 else None
492494

493-
if unload:
494-
shared.sd_model.cond_stage_model.to(devices.cpu)
495-
shared.sd_model.first_stage_model.to(devices.cpu)
495+
if unload:
496+
shared.sd_model.cond_stage_model.to(devices.cpu)
497+
shared.sd_model.first_stage_model.to(devices.cpu)
496498

497-
if image is not None:
498-
shared.state.current_image = image
499-
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
500-
last_saved_image += f", prompt: {preview_text}"
499+
if image is not None:
500+
shared.state.current_image = image
501+
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
502+
last_saved_image += f", prompt: {preview_text}"
501503

502-
shared.state.job_no = hypernetwork.step
504+
shared.state.job_no = hypernetwork.step
503505

504-
shared.state.textinfo = f"""
506+
shared.state.textinfo = f"""
505507
<p>
506508
Loss: {previous_mean_loss:.7f}<br/>
507509
Step: {hypernetwork.step}<br/>
@@ -510,7 +512,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
510512
Last saved image: {html.escape(last_saved_image)}<br/>
511513
</p>
512514
"""
513-
515+
finally:
516+
if weights:
517+
for weight in weights:
518+
weight.requires_grad = False
519+
if unload:
520+
shared.sd_model.cond_stage_model.to(devices.device)
521+
shared.sd_model.first_stage_model.to(devices.device)
522+
514523
report_statistics(loss_dict)
515524
checkpoint = sd_models.select_checkpoint()
516525

0 commit comments

Comments (0)