Skip to content

Commit a890cbe

Browse files
committed
debug
1 parent 09eb347 commit a890cbe

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

examples/research_projects/lpl/lpl_loss.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,16 @@ def __init__(
5555
super().__init__()
5656
self.vae = vae
5757
self.decoder = self.vae.decoder
58-
# Use vae.config for scaling and shift factors
59-
self.scale = getattr(self.vae.config, "scaling_factor", 1.0)
60-
self.shift = getattr(self.vae.config, "shift_factor", 0.0)
58+
# Store scaling factors as tensors on the correct device
59+
device = next(self.vae.parameters()).device
60+
self.scale = torch.tensor(getattr(self.vae.config, "scaling_factor", 1.0), device=device)
61+
self.shift = torch.tensor(getattr(self.vae.config, "shift_factor", 0.0), device=device)
62+
63+
# Debug logging
64+
print(f"VAE config: {self.vae.config}")
65+
print(f"Initial scale: {self.scale}, shift: {self.shift}")
66+
print(f"Device: {device}")
67+
6168
self.gradient_checkpointing = grad_ckpt
6269
self.pow_law = pow_law
6370
self.norm_type = norm_type.lower()
@@ -130,6 +137,11 @@ def custom_forward(*inputs):
130137
return features
131138

132139
def get_loss(self, input, target, get_hist=False):
140+
# Debug logging for each call
141+
print(f"Current scale: {self.scale}, shift: {self.shift}")
142+
print(f"Input shape: {input.shape}, dtype: {input.dtype}")
143+
print(f"Scale type: {type(self.scale)}, shift type: {type(self.shift)}")
144+
133145
if self.feature_type == "feature":
134146
inp_f = self.get_features(self.shift + input / self.scale)
135147
tar_f = self.get_features(self.shift + target / self.scale, disable_grads=True)

0 commit comments

Comments
 (0)