
Commit 2e348b8 (parent: e5f63bf)

Commit message: quality

1 file changed: 19 additions, 19 deletions


examples/research_projects/autoencoderkl/train_autoencoderkl.py

@@ -61,9 +61,7 @@
 
 
 @torch.no_grad()
-def log_validation(
-    vae, args, accelerator, weight_dtype, step, is_final_validation=False
-):
+def log_validation(vae, args, accelerator, weight_dtype, step, is_final_validation=False):
     logger.info("Running validation... ")
 
     if not is_final_validation:
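
This hunk only joins the log_validation signature onto one line, so behavior is unchanged. The @torch.no_grad() decorator above it is what makes the whole validation pass inference-only; a minimal, self-contained sketch of that effect (the validate/model names below are illustrative, not from the script):

import torch

@torch.no_grad()
def validate(model, x):
    # No autograd graph is recorded here: the output has requires_grad=False
    # and intermediate activations are freed immediately, which is why the
    # script's log_validation carries the same decorator.
    return model(x)

model = torch.nn.Linear(4, 2)
out = validate(model, torch.randn(1, 4))
assert not out.requires_grad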
@@ -91,23 +89,18 @@ def log_validation(
         with inference_ctx:
             reconstructions = vae(targets).sample
 
-        images.append(
-            torch.cat([targets.cpu(), reconstructions.cpu()], axis=0)
-        )
+        images.append(torch.cat([targets.cpu(), reconstructions.cpu()], axis=0))
 
     tracker_key = "test" if is_final_validation else "validation"
     for tracker in accelerator.trackers:
         if tracker.name == "tensorboard":
             np_images = np.stack([np.asarray(img) for img in images])
-            tracker.writer.add_images(
-                f"{tracker_key}: Original (left), Reconstruction (right)", np_images, step
-            )
+            tracker.writer.add_images(f"{tracker_key}: Original (left), Reconstruction (right)", np_images, step)
         elif tracker.name == "wandb":
             tracker.log(
                 {
                     f"{tracker_key}: Original (left), Reconstruction (right)": [
-                        wandb.Image(torchvision.utils.make_grid(image))
-                        for _, image in enumerate(images)
+                        wandb.Image(torchvision.utils.make_grid(image)) for _, image in enumerate(images)
                     ]
                 }
             )
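
For orientation, the block reformatted above gathers (original, reconstruction) pairs and routes them to whichever trackers are active. A condensed sketch of the same flow; the batch["pixel_values"] key and the device handling are assumptions about the surrounding script:

import torch
import torchvision

@torch.no_grad()
def collect_grids(vae, test_dataloader, device="cuda"):
    # Pair each batch of originals with its VAE reconstruction; AutoencoderKL's
    # forward pass returns an output object whose .sample field holds the
    # decoded images.
    grids = []
    for batch in test_dataloader:
        targets = batch["pixel_values"].to(device)
        reconstructions = vae(targets).sample
        pair = torch.cat([targets.cpu(), reconstructions.cpu()], dim=0)
        grids.append(torchvision.utils.make_grid(pair))
    return grids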
@@ -677,7 +670,9 @@ def load_model_hook(models, input_dir):
 
             # pop models so that they are not loaded again
             model = models.pop()
-            load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).load_state_dict(os.path.join(input_dir, "discriminator", "pytorch_model.bin"))
+            load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).load_state_dict(
+                os.path.join(input_dir, "discriminator", "pytorch_model.bin")
+            )
             model.load_state_dict(load_model.state_dict())
             del load_model
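
One caveat worth flagging: this hunk only re-wraps the long line, and the call it preserves looks suspect. torch.nn.Module.load_state_dict expects a state-dict mapping, not a file path, and it returns a (missing_keys, unexpected_keys) named tuple rather than the module, so the follow-up load_model.state_dict() would likely fail at runtime. A conventional version of the restore step might look like this sketch (restore_discriminator is a hypothetical helper, not part of the script):

import os
import torch

def restore_discriminator(model, input_dir):
    # Load the serialized weights into a mapping first, then apply that
    # mapping to the live module.
    path = os.path.join(input_dir, "discriminator", "pytorch_model.bin")
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)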

@@ -690,7 +685,6 @@ def load_model_hook(models, input_dir):
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
 
-
     vae.requires_grad_(True)
     if args.decoder_only:
         vae.encoder.requires_grad_(False)
@@ -723,9 +717,7 @@ def load_model_hook(models, input_dir):
     )
 
     if unwrap_model(vae).dtype != torch.float32:
-        raise ValueError(
-            f"VAE loaded as datatype {unwrap_model(vae).dtype}. {low_precision_error_string}"
-        )
+        raise ValueError(f"VAE loaded as datatype {unwrap_model(vae).dtype}. {low_precision_error_string}")
 
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
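
The collapsed raise ValueError(...) simply fits the repo's longer line limit. The guard itself follows the usual mixed-precision convention: trainable weights stay in fp32 while autocast handles low-precision compute. A minimal sketch of the same check, with an illustrative stand-in model:

import torch

model = torch.nn.Linear(4, 4)  # stand-in for the trainable VAE

# Trainable parameters should remain fp32 even under mixed-precision training;
# autocast downcasts compute on the fly without touching the master weights.
if next(model.parameters()).dtype != torch.float32:
    raise ValueError(f"Model loaded as datatype {next(model.parameters()).dtype}.")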
@@ -802,7 +794,15 @@ def load_model_hook(models, input_dir):
     )
 
     # Prepare everything with our `accelerator`.
-    vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler = accelerator.prepare(
+    (
+        vae,
+        discriminator,
+        optimizer,
+        disc_optimizer,
+        train_dataloader,
+        lr_scheduler,
+        disc_lr_scheduler,
+    ) = accelerator.prepare(
         vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler
     )
 
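The parenthesized tuple target is black-style formatting for a long unpacking; accelerator.prepare still returns the prepared objects in the order they were passed, so the two sides must line up. A self-contained sketch of the same pattern with fewer objects (all names are illustrative):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = torch.utils.data.DataLoader(torch.randn(8, 4), batch_size=2)

# prepare() wraps each object for the current device / distributed setup and
# returns them in the order they were passed, so order must match on both sides.
(
    model,
    optimizer,
    dataloader,
) = accelerator.prepare(model, optimizer, dataloader)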

@@ -935,7 +935,7 @@ def load_model_hook(models, input_dir):
                     "disc_weight": disc_weight.detach().mean().item(),
                     "disc_factor": disc_factor,
                     "g_loss": g_loss.detach().mean().item(),
-                    "lr": lr_scheduler.get_last_lr()[0]
+                    "lr": lr_scheduler.get_last_lr()[0],
                 }
 
                 accelerator.backward(loss)
@@ -956,7 +956,7 @@ def load_model_hook(models, input_dir):
                     "disc_loss": disc_loss.detach().mean().item(),
                     "logits_real": logits_real.detach().mean().item(),
                     "logits_fake": logits_fake.detach().mean().item(),
-                    "disc_lr": disc_lr_scheduler.get_last_lr()[0]
+                    "disc_lr": disc_lr_scheduler.get_last_lr()[0],
                 }
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
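
Both of the final hunks just add the trailing comma that black requires in multi-line literals; the logs dicts are otherwise unchanged. For context, a sketch of how such a dict is typically handed to accelerate's tracker API; global_step and the values below are placeholders, not taken from the script:

from accelerate import Accelerator

accelerator = Accelerator()  # in the script this is configured with log_with=...
global_step = 0  # assumed counter; the real loop advances this per optimizer step

logs = {
    "disc_loss": 0.0,  # placeholder values; the script reads these from live tensors
    "disc_lr": 1e-4,
}
accelerator.log(logs, step=global_step)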
