Skip to content

Commit d444508

Browse files
committed
Make it possible to instantiate the pipeline without vae and unet.
Compute n_targets on the fly; add a consistency check for vae_scale_factor.
1 parent 2ba8d4d commit d444508

File tree

3 files changed

+21
-5
lines changed

3 files changed

+21
-5
lines changed

src/diffusers/pipelines/marigold/pipeline_marigold_depth.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,11 @@ def check_inputs(
209209
output_type: str,
210210
output_uncertainty: bool,
211211
) -> int:
212+
actual_vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
213+
if actual_vae_scale_factor != self.vae_scale_factor:
214+
raise ValueError(
215+
f"`vae_scale_factor` computed at initialization ({self.vae_scale_factor}) differs from the actual one ({actual_vae_scale_factor})."
216+
)
212217
if num_inference_steps is None:
213218
raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
214219
if num_inference_steps < 1:

src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,6 @@ def __init__(
196196
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
197197

198198
self.target_properties = target_properties
199-
self.n_targets = self.unet.config.out_channels // self.vae.config.latent_channels
200199
self.default_denoising_steps = default_denoising_steps
201200
self.default_processing_resolution = default_processing_resolution
202201

@@ -219,6 +218,11 @@ def check_inputs(
219218
output_type: str,
220219
output_uncertainty: bool,
221220
) -> int:
221+
actual_vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
222+
if actual_vae_scale_factor != self.vae_scale_factor:
223+
raise ValueError(
224+
f"`vae_scale_factor` computed at initialization ({self.vae_scale_factor}) differs from the actual one ({actual_vae_scale_factor})."
225+
)
222226
if num_inference_steps is None:
223227
raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
224228
if num_inference_steps < 1:
@@ -310,7 +314,7 @@ def check_inputs(
310314
W, H = new_W, new_H
311315
w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor
312316
h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor
313-
shape_expected = (num_images * ensemble_size, self.n_targets * self.vae.config.latent_channels, h, w)
317+
shape_expected = (num_images * ensemble_size, self.unet.config.out_channels, h, w)
314318

315319
if latents.shape != shape_expected:
316320
raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.")
@@ -546,8 +550,9 @@ def __call__(
546550
# 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`,
547551
# which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`.
548552
# Model invocation: self.vae.decoder.
553+
n_targets = self.unet.config.out_channels // self.vae.config.latent_channels
549554
pred_latent_for_decoding = pred_latent.reshape(
550-
num_images * ensemble_size * self.n_targets, self.vae.config.latent_channels, *pred_latent.shape[2:]
555+
num_images * ensemble_size * n_targets, self.vae.config.latent_channels, *pred_latent.shape[2:]
551556
) # [N*E*T,4,PPH,PPW]
552557
prediction = torch.cat(
553558
[
@@ -572,7 +577,7 @@ def __call__(
572577
uncertainty = None
573578
if ensemble_size > 1:
574579
prediction = prediction.reshape(
575-
num_images, ensemble_size, self.n_targets, *prediction.shape[1:]
580+
num_images, ensemble_size, n_targets, *prediction.shape[1:]
576581
) # [N,E,T,3,PH,PW]
577582
prediction = [
578583
self.ensemble_intrinsics(prediction[i], output_uncertainty, **(ensembling_kwargs or {}))
@@ -645,8 +650,9 @@ def retrieve_latents(encoder_output):
645650

646651
pred_latent = latents
647652
if pred_latent is None:
653+
n_targets = self.unet.config.out_channels // self.vae.config.latent_channels
648654
pred_latent = randn_tensor(
649-
(N_E, self.n_targets * C, H, W),
655+
(N_E, n_targets * C, H, W),
650656
generator=generator,
651657
device=image_latent.device,
652658
dtype=image_latent.dtype,

src/diffusers/pipelines/marigold/pipeline_marigold_normals.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,11 @@ def check_inputs(
196196
output_type: str,
197197
output_uncertainty: bool,
198198
) -> int:
199+
actual_vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
200+
if actual_vae_scale_factor != self.vae_scale_factor:
201+
raise ValueError(
202+
f"`vae_scale_factor` computed at initialization ({self.vae_scale_factor}) differs from the actual one ({actual_vae_scale_factor})."
203+
)
199204
if num_inference_steps is None:
200205
raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
201206
if num_inference_steps < 1:

0 commit comments

Comments
 (0)