@@ -55,9 +55,16 @@ def __init__(
5555 super ().__init__ ()
5656 self .vae = vae
5757 self .decoder = self .vae .decoder
58- # Use vae.config for scaling and shift factors
59- self .scale = getattr (self .vae .config , "scaling_factor" , 1.0 )
60- self .shift = getattr (self .vae .config , "shift_factor" , 0.0 )
58+ # Store scaling factors as tensors on the correct device
59+ device = next (self .vae .parameters ()).device
60+ self .scale = torch .tensor (getattr (self .vae .config , "scaling_factor" , 1.0 ), device = device )
61+ self .shift = torch .tensor (getattr (self .vae .config , "shift_factor" , 0.0 ), device = device )
62+
63+ # Debug logging
64+ print (f"VAE config: { self .vae .config } " )
65+ print (f"Initial scale: { self .scale } , shift: { self .shift } " )
66+ print (f"Device: { device } " )
67+
6168 self .gradient_checkpointing = grad_ckpt
6269 self .pow_law = pow_law
6370 self .norm_type = norm_type .lower ()
@@ -130,6 +137,11 @@ def custom_forward(*inputs):
130137 return features
131138
132139 def get_loss (self , input , target , get_hist = False ):
140+ # Debug logging for each call
141+ print (f"Current scale: { self .scale } , shift: { self .shift } " )
142+ print (f"Input shape: { input .shape } , dtype: { input .dtype } " )
143+ print (f"Scale type: { type (self .scale )} , shift type: { type (self .shift )} " )
144+
133145 if self .feature_type == "feature" :
134146 inp_f = self .get_features (self .shift + input / self .scale )
135147 tar_f = self .get_features (self .shift + target / self .scale , disable_grads = True )
0 commit comments