Skip to content

Commit b3d1152

Browse files
committed
revert to having n_targets as a pipeline property
1 parent d444508 commit b3d1152

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ def __init__(
203203

204204
self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
205205

206+
@property
207+
def n_targets(self):
208+
return self.unet.config.out_channels // self.vae.config.latent_channels
209+
206210
def check_inputs(
207211
self,
208212
image: PipelineImageInput,
@@ -550,9 +554,8 @@ def __call__(
550554
# 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`,
551555
# which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`.
552556
# Model invocation: self.vae.decoder.
553-
n_targets = self.unet.config.out_channels // self.vae.config.latent_channels
554557
pred_latent_for_decoding = pred_latent.reshape(
555-
num_images * ensemble_size * n_targets, self.vae.config.latent_channels, *pred_latent.shape[2:]
558+
num_images * ensemble_size * self.n_targets, self.vae.config.latent_channels, *pred_latent.shape[2:]
556559
) # [N*E*T,4,PPH,PPW]
557560
prediction = torch.cat(
558561
[
@@ -577,7 +580,7 @@ def __call__(
577580
uncertainty = None
578581
if ensemble_size > 1:
579582
prediction = prediction.reshape(
580-
num_images, ensemble_size, n_targets, *prediction.shape[1:]
583+
num_images, ensemble_size, self.n_targets, *prediction.shape[1:]
581584
) # [N,E,T,3,PH,PW]
582585
prediction = [
583586
self.ensemble_intrinsics(prediction[i], output_uncertainty, **(ensembling_kwargs or {}))
@@ -650,9 +653,8 @@ def retrieve_latents(encoder_output):
650653

651654
pred_latent = latents
652655
if pred_latent is None:
653-
n_targets = self.unet.config.out_channels // self.vae.config.latent_channels
654656
pred_latent = randn_tensor(
655-
(N_E, n_targets * C, H, W),
657+
(N_E, self.n_targets * C, H, W),
656658
generator=generator,
657659
device=image_latent.device,
658660
dtype=image_latent.dtype,

0 commit comments

Comments
 (0)