@@ -57,12 +57,19 @@ def __init__(
5757 self .decoder = self .vae .decoder
5858 # Store scaling factors as tensors on the correct device
5959 device = next (self .vae .parameters ()).device
60- self .scale = torch .tensor (getattr (self .vae .config , "scaling_factor" , 1.0 ), device = device )
61- self .shift = torch .tensor (getattr (self .vae .config , "shift_factor" , 0.0 ), device = device )
60+
61+ # Get scaling factors with proper defaults and handle None values
62+ scale_factor = getattr (self .vae .config , "scaling_factor" , None )
63+ shift_factor = getattr (self .vae .config , "shift_factor" , None )
64+
65+ # Convert to tensors with proper defaults
66+ self .scale = torch .tensor (1.0 if scale_factor is None else scale_factor , device = device )
67+ self .shift = torch .tensor (0.0 if shift_factor is None else shift_factor , device = device )
6268
6369 # Debug logging
6470 print (f"VAE config: { self .vae .config } " )
65- print (f"Initial scale: { self .scale } , shift: { self .shift } " )
71+ print (f"Raw scale factor: { scale_factor } , shift factor: { shift_factor } " )
72+ print (f"Final scale tensor: { self .scale } , shift tensor: { self .shift } " )
6673 print (f"Device: { device } " )
6774
6875 self .gradient_checkpointing = grad_ckpt
0 commit comments