Skip to content

Commit 9c9b4b0

Browse files
committed
better
1 parent a890cbe commit 9c9b4b0

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

examples/research_projects/lpl/lpl_loss.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -57,12 +57,19 @@ def __init__(
5757
self.decoder = self.vae.decoder
5858
# Store scaling factors as tensors on the correct device
5959
device = next(self.vae.parameters()).device
60-
self.scale = torch.tensor(getattr(self.vae.config, "scaling_factor", 1.0), device=device)
61-
self.shift = torch.tensor(getattr(self.vae.config, "shift_factor", 0.0), device=device)
60+
61+
# Get scaling factors with proper defaults and handle None values
62+
scale_factor = getattr(self.vae.config, "scaling_factor", None)
63+
shift_factor = getattr(self.vae.config, "shift_factor", None)
64+
65+
# Convert to tensors with proper defaults
66+
self.scale = torch.tensor(1.0 if scale_factor is None else scale_factor, device=device)
67+
self.shift = torch.tensor(0.0 if shift_factor is None else shift_factor, device=device)
6268

6369
# Debug logging
6470
print(f"VAE config: {self.vae.config}")
65-
print(f"Initial scale: {self.scale}, shift: {self.shift}")
71+
print(f"Raw scale factor: {scale_factor}, shift factor: {shift_factor}")
72+
print(f"Final scale tensor: {self.scale}, shift tensor: {self.shift}")
6673
print(f"Device: {device}")
6774

6875
self.gradient_checkpointing = grad_ckpt

0 commit comments

Comments
 (0)