@@ -203,6 +203,10 @@ def __init__(
203203
204204 self .image_processor = MarigoldImageProcessor (vae_scale_factor = self .vae_scale_factor )
205205
@property
def n_targets(self):
    """Return how many latent-sized channel groups the UNet output holds.

    The value is the UNet's configured ``out_channels`` floor-divided by the
    VAE's configured ``latent_channels``. Because floor division is used, any
    remainder is silently discarded.
    """
    unet_out_channels = self.unet.config.out_channels
    latent_channels = self.vae.config.latent_channels
    return unet_out_channels // latent_channels
209+
206210 def check_inputs (
207211 self ,
208212 image : PipelineImageInput ,
@@ -550,9 +554,8 @@ def __call__(
550554 # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`,
551555 # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`.
552556 # Model invocation: self.vae.decoder.
553- n_targets = self .unet .config .out_channels // self .vae .config .latent_channels
554557 pred_latent_for_decoding = pred_latent .reshape (
555- num_images * ensemble_size * n_targets , self .vae .config .latent_channels , * pred_latent .shape [2 :]
558+ num_images * ensemble_size * self . n_targets , self .vae .config .latent_channels , * pred_latent .shape [2 :]
556559 ) # [N*E*T,4,PPH,PPW]
557560 prediction = torch .cat (
558561 [
@@ -577,7 +580,7 @@ def __call__(
577580 uncertainty = None
578581 if ensemble_size > 1 :
579582 prediction = prediction .reshape (
580- num_images , ensemble_size , n_targets , * prediction .shape [1 :]
583+ num_images , ensemble_size , self . n_targets , * prediction .shape [1 :]
581584 ) # [N,E,T,3,PH,PW]
582585 prediction = [
583586 self .ensemble_intrinsics (prediction [i ], output_uncertainty , ** (ensembling_kwargs or {}))
@@ -650,9 +653,8 @@ def retrieve_latents(encoder_output):
650653
651654 pred_latent = latents
652655 if pred_latent is None :
653- n_targets = self .unet .config .out_channels // self .vae .config .latent_channels
654656 pred_latent = randn_tensor (
655- (N_E , n_targets * C , H , W ),
657+ (N_E , self . n_targets * C , H , W ),
656658 generator = generator ,
657659 device = image_latent .device ,
658660 dtype = image_latent .dtype ,
0 commit comments