
Commit ab05a74

Revert "Add cleanup after training"
This reverts commit 3ce2bfd.
1 parent a27d19d commit ab05a74
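
The change being reverted had wrapped the hypernetwork training loop in a try/finally so that cleanup ran even when training raised or was interrupted: the trained weights were set back to requires_grad = False and, when unload was enabled, the cond_stage/first_stage models were moved back to devices.device. Below is a minimal, self-contained sketch of that try/finally pattern; the train() helper and the dummy weights are hypothetical stand-ins for illustration, not the code from modules/hypernetworks/hypernetwork.py.

import torch

def train(weights, steps):
    # Hypothetical stand-in for the training loop; only the cleanup pattern matters here.
    for w in weights:
        w.requires_grad_(True)   # gradients enabled only while training runs
    try:
        for _ in range(steps):
            loss = sum((w ** 2).sum() for w in weights)
            loss.backward()      # a real loop would also step and zero an optimizer
    finally:
        # Runs even on an exception or KeyboardInterrupt, mirroring the reverted cleanup.
        for w in weights:
            w.requires_grad_(False)

weights = [torch.zeros(4)]
train(weights, steps=3)
print([w.requires_grad for w in weights])  # -> [False]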

File tree

2 files changed: +186 -200 lines changed


modules/hypernetworks/hypernetwork.py

Lines changed: 96 additions & 105 deletions
@@ -398,112 +398,110 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     forced_filename = "<none>"

     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-
-    try:
-        for i, entries in pbar:
-            hypernetwork.step = i + ititial_step
-            if len(loss_dict) > 0:
-                previous_mean_losses = [i[-1] for i in loss_dict.values()]
-                previous_mean_loss = mean(previous_mean_losses)
-
-            scheduler.apply(optimizer, hypernetwork.step)
-            if scheduler.finished:
-                break
-
-            if shared.state.interrupted:
-                break
-
-            with torch.autocast("cuda"):
-                c = stack_conds([entry.cond for entry in entries]).to(devices.device)
-                # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
-                x = torch.stack([entry.latent for entry in entries]).to(devices.device)
-                loss = shared.sd_model(x, c)[0]
-                del x
-                del c
-
-            losses[hypernetwork.step % losses.shape[0]] = loss.item()
-            for entry in entries:
-                loss_dict[entry.filename].append(loss.item())
-
-            optimizer.zero_grad()
-            weights[0].grad = None
-            loss.backward()
-
-            if weights[0].grad is None:
-                steps_without_grad += 1
-            else:
-                steps_without_grad = 0
-            assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
-
-            optimizer.step()
-
-            steps_done = hypernetwork.step + 1
-
-            if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
-                raise RuntimeError("Loss diverged.")
+    for i, entries in pbar:
+        hypernetwork.step = i + ititial_step
+        if len(loss_dict) > 0:
+            previous_mean_losses = [i[-1] for i in loss_dict.values()]
+            previous_mean_loss = mean(previous_mean_losses)

-            if len(previous_mean_losses) > 1:
-                std = stdev(previous_mean_losses)
+        scheduler.apply(optimizer, hypernetwork.step)
+        if scheduler.finished:
+            break
+
+        if shared.state.interrupted:
+            break
+
+        with torch.autocast("cuda"):
+            c = stack_conds([entry.cond for entry in entries]).to(devices.device)
+            # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+            loss = shared.sd_model(x, c)[0]
+            del x
+            del c
+
+        losses[hypernetwork.step % losses.shape[0]] = loss.item()
+        for entry in entries:
+            loss_dict[entry.filename].append(loss.item())
+
+        optimizer.zero_grad()
+        weights[0].grad = None
+        loss.backward()
+
+        if weights[0].grad is None:
+            steps_without_grad += 1
         else:
-                std = 0
-            dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
-            pbar.set_description(dataset_loss_info)
-
-            if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
-                # Before saving, change name to match current checkpoint.
-                hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
-                last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
-                hypernetwork.save(last_saved_file)
-
-            textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
-                "loss": f"{previous_mean_loss:.7f}",
-                "learn_rate": scheduler.learn_rate
-            })
-
-            if images_dir is not None and steps_done % create_image_every == 0:
-                forced_filename = f'{hypernetwork_name}-{steps_done}'
-                last_saved_image = os.path.join(images_dir, forced_filename)
-
-                optimizer.zero_grad()
-                shared.sd_model.cond_stage_model.to(devices.device)
-                shared.sd_model.first_stage_model.to(devices.device)
-
-                p = processing.StableDiffusionProcessingTxt2Img(
-                    sd_model=shared.sd_model,
-                    do_not_save_grid=True,
-                    do_not_save_samples=True,
-                )
+            steps_without_grad = 0
+        assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'

-                if preview_from_txt2img:
-                    p.prompt = preview_prompt
-                    p.negative_prompt = preview_negative_prompt
-                    p.steps = preview_steps
-                    p.sampler_index = preview_sampler_index
-                    p.cfg_scale = preview_cfg_scale
-                    p.seed = preview_seed
-                    p.width = preview_width
-                    p.height = preview_height
-                else:
-                    p.prompt = entries[0].cond_text
-                    p.steps = 20
+        optimizer.step()

-                preview_text = p.prompt
+        steps_done = hypernetwork.step + 1

-                processed = processing.process_images(p)
-                image = processed.images[0] if len(processed.images)>0 else None
+        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+            raise RuntimeError("Loss diverged.")
+
+        if len(previous_mean_losses) > 1:
+            std = stdev(previous_mean_losses)
+        else:
+            std = 0
+        dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
+        pbar.set_description(dataset_loss_info)
+
+        if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
+            # Before saving, change name to match current checkpoint.
+            hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
+            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
+            hypernetwork.save(last_saved_file)
+
+        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
+            "loss": f"{previous_mean_loss:.7f}",
+            "learn_rate": scheduler.learn_rate
+        })
+
+        if images_dir is not None and steps_done % create_image_every == 0:
+            forced_filename = f'{hypernetwork_name}-{steps_done}'
+            last_saved_image = os.path.join(images_dir, forced_filename)
+
+            optimizer.zero_grad()
+            shared.sd_model.cond_stage_model.to(devices.device)
+            shared.sd_model.first_stage_model.to(devices.device)

-                if unload:
-                    shared.sd_model.cond_stage_model.to(devices.cpu)
-                    shared.sd_model.first_stage_model.to(devices.cpu)
+            p = processing.StableDiffusionProcessingTxt2Img(
+                sd_model=shared.sd_model,
+                do_not_save_grid=True,
+                do_not_save_samples=True,
+            )

-                if image is not None:
-                    shared.state.current_image = image
-                    last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
-                    last_saved_image += f", prompt: {preview_text}"
+            if preview_from_txt2img:
+                p.prompt = preview_prompt
+                p.negative_prompt = preview_negative_prompt
+                p.steps = preview_steps
+                p.sampler_index = preview_sampler_index
+                p.cfg_scale = preview_cfg_scale
+                p.seed = preview_seed
+                p.width = preview_width
+                p.height = preview_height
+            else:
+                p.prompt = entries[0].cond_text
+                p.steps = 20
+
+            preview_text = p.prompt
+
+            processed = processing.process_images(p)
+            image = processed.images[0] if len(processed.images)>0 else None
+
+            if unload:
+                shared.sd_model.cond_stage_model.to(devices.cpu)
+                shared.sd_model.first_stage_model.to(devices.cpu)

-                shared.state.job_no = hypernetwork.step
+            if image is not None:
+                shared.state.current_image = image
+                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+                last_saved_image += f", prompt: {preview_text}"

-                shared.state.textinfo = f"""
+        shared.state.job_no = hypernetwork.step
+
+        shared.state.textinfo = f"""
 <p>
 Loss: {previous_mean_loss:.7f}<br/>
 Step: {hypernetwork.step}<br/>
@@ -512,14 +510,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
-    finally:
-        if weights:
-            for weight in weights:
-                weight.requires_grad = False
-        if unload:
-            shared.sd_model.cond_stage_model.to(devices.device)
-            shared.sd_model.first_stage_model.to(devices.device)
-
+
     report_statistics(loss_dict)
     checkpoint = sd_models.select_checkpoint()