@@ -196,7 +196,6 @@ def __init__(
196196 self .vae_scale_factor = 2 ** (len (self .vae .config .block_out_channels ) - 1 ) if getattr (self , "vae" , None ) else 8
197197
198198 self .target_properties = target_properties
199- self .n_targets = self .unet .config .out_channels // self .vae .config .latent_channels
200199 self .default_denoising_steps = default_denoising_steps
201200 self .default_processing_resolution = default_processing_resolution
202201
@@ -219,6 +218,11 @@ def check_inputs(
219218 output_type : str ,
220219 output_uncertainty : bool ,
221220 ) -> int :
221+ actual_vae_scale_factor = 2 ** (len (self .vae .config .block_out_channels ) - 1 )
222+ if actual_vae_scale_factor != self .vae_scale_factor :
223+ raise ValueError (
224+ f"`vae_scale_factor` computed at initialization ({ self .vae_scale_factor } ) differs from the actual one ({ actual_vae_scale_factor } )."
225+ )
222226 if num_inference_steps is None :
223227 raise ValueError ("`num_inference_steps` is not specified and could not be resolved from the model config." )
224228 if num_inference_steps < 1 :
@@ -310,7 +314,7 @@ def check_inputs(
310314 W , H = new_W , new_H
311315 w = (W + self .vae_scale_factor - 1 ) // self .vae_scale_factor
312316 h = (H + self .vae_scale_factor - 1 ) // self .vae_scale_factor
313- shape_expected = (num_images * ensemble_size , self .n_targets * self . vae . config .latent_channels , h , w )
317+ shape_expected = (num_images * ensemble_size , self .unet . config .out_channels , h , w )
314318
315319 if latents .shape != shape_expected :
316320 raise ValueError (f"`latents` has unexpected shape={ latents .shape } expected={ shape_expected } ." )
@@ -546,8 +550,9 @@ def __call__(
546550 # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`,
547551 # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`.
548552 # Model invocation: self.vae.decoder.
553+ n_targets = self .unet .config .out_channels // self .vae .config .latent_channels
549554 pred_latent_for_decoding = pred_latent .reshape (
550- num_images * ensemble_size * self . n_targets , self .vae .config .latent_channels , * pred_latent .shape [2 :]
555+ num_images * ensemble_size * n_targets , self .vae .config .latent_channels , * pred_latent .shape [2 :]
551556 ) # [N*E*T,4,PPH,PPW]
552557 prediction = torch .cat (
553558 [
@@ -572,7 +577,7 @@ def __call__(
572577 uncertainty = None
573578 if ensemble_size > 1 :
574579 prediction = prediction .reshape (
575- num_images , ensemble_size , self . n_targets , * prediction .shape [1 :]
580+ num_images , ensemble_size , n_targets , * prediction .shape [1 :]
576581 ) # [N,E,T,3,PH,PW]
577582 prediction = [
578583 self .ensemble_intrinsics (prediction [i ], output_uncertainty , ** (ensembling_kwargs or {}))
@@ -645,8 +650,9 @@ def retrieve_latents(encoder_output):
645650
646651 pred_latent = latents
647652 if pred_latent is None :
653+ n_targets = self .unet .config .out_channels // self .vae .config .latent_channels
648654 pred_latent = randn_tensor (
649- (N_E , self . n_targets * C , H , W ),
655+ (N_E , n_targets * C , H , W ),
650656 generator = generator ,
651657 device = image_latent .device ,
652658 dtype = image_latent .dtype ,
0 commit comments