diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index 3464d88145ea..0ad558fef9d7 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -35,6 +35,7 @@ Available models: | [`LTX Video 2B 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` | | [`LTX Video 2B 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` | | [`LTX Video 13B 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors) | `torch.bfloat16` | +| [`LTX Video 13B 0.9.7 (distilled)`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors) | `torch.bfloat16` | | [`LTX Video Spatial Upscaler 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors) | `torch.bfloat16` | Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository. @@ -47,6 +48,14 @@ For the best results, it is recommended to follow the guidelines mentioned in th - For variants with a timestep-aware VAE (LTXV 0.9.1 and above), it is recommended to set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`. - For variants that support interpolation between multiple conditioning images and videos (LTXV 0.9.5 and above), it is recommended to use similar looking images/videos for the best results. High divergence between the conditionings may lead to abrupt transitions in the generated video. + + + + +The examples below show some recommended generation settings, but note that all features supported in the original [LTX Video repository](https://github.com/Lightricks/LTX-Video) are not supported in `diffusers` yet (for example, Spatio-temporal Guidance and CRF compression for image inputs). This will gradually be supported in the future. For the best possible generation quality, we recommend using the code from the original repository. + + + ## Using LTX Video 13B 0.9.7 LTX Video 0.9.7 comes with a spatial latent upscaler and a 13B parameter transformer. The inference involves generating a low resolution video first, which is very fast, followed by upscaling and refining the generated video. @@ -59,8 +68,8 @@ from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition from diffusers.utils import export_to_video, load_video -pipe = LTXConditionPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", torch_dtype=torch.bfloat16) -pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers", vae=pipe.vae, torch_dtype=torch.bfloat16) +pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.7-dev", torch_dtype=torch.bfloat16) +pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("Lightricks/ltxv-spatial-upscaler-0.9.7", vae=pipe.vae, torch_dtype=torch.bfloat16) pipe.to("cuda") pipe_upsample.to("cuda") pipe.vae.enable_tiling() @@ -93,6 +102,11 @@ latents = pipe( height=downscaled_height, num_frames=num_frames, num_inference_steps=30, + decode_timestep=0.05, + decode_noise_scale=0.025, + image_cond_noise_scale=0.0, + guidance_scale=5.0, + guidance_rescale=0.7, generator=torch.Generator().manual_seed(0), output_type="latent", ).frames @@ -117,7 +131,97 @@ video = pipe( num_inference_steps=10, latents=upscaled_latents, decode_timestep=0.05, - image_cond_noise_scale=0.025, + decode_noise_scale=0.025, + image_cond_noise_scale=0.0, + guidance_scale=5.0, + guidance_rescale=0.7, + generator=torch.Generator().manual_seed(0), + output_type="pil", +).frames[0] + +# Part 4. Downscale the video to the expected resolution +video = [frame.resize((expected_width, expected_height)) for frame in video] + +export_to_video(video, "output.mp4", fps=24) +``` + +## Using LTX Video 0.9.7 (distilled) + +The same example as above can be used with the exception of the `guidance_scale` parameter. The model is both guidance and timestep distilled in order to speedup generation. It requires `guidance_scale` to be set to `1.0`. Additionally, to benefit from the timestep distillation, `num_inference_steps` can be set between `4` and `10` for good generation quality. + +Additionally, custom timesteps can also be used for conditioning the generation. The authors recommend using the following timesteps for best results: +- Base model inference to prepare for upscaling: `[1000, 993, 987, 981, 975, 909, 725, 0.03]` +- Upscaling: `[1000, 909, 725, 421, 0]` + +
+ Full example + +```python +import torch +from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline +from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition +from diffusers.utils import export_to_video, load_video + +pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.7-distilled", torch_dtype=torch.bfloat16) +pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("Lightricks/ltxv-spatial-upscaler-0.9.7", vae=pipe.vae, torch_dtype=torch.bfloat16) +pipe.to("cuda") +pipe_upsample.to("cuda") +pipe.vae.enable_tiling() + +def round_to_nearest_resolution_acceptable_by_vae(height, width): + height = height - (height % pipe.vae_temporal_compression_ratio) + width = width - (width % pipe.vae_temporal_compression_ratio) + return height, width + +prompt = "artistic anatomical 3d render, utlra quality, human half full male body with transparent skin revealing structure instead of organs, muscular, intricate creative patterns, monochromatic with backlighting, lightning mesh, scientific concept art, blending biology with botany, surreal and ethereal quality, unreal engine 5, ray tracing, ultra realistic, 16K UHD, rich details. camera zooms out in a rotating fashion" +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" +expected_height, expected_width = 768, 1152 +downscale_factor = 2 / 3 +num_frames = 161 + +# Part 1. Generate video at smaller resolution +downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor) +downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width) +latents = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=downscaled_width, + height=downscaled_height, + num_frames=num_frames, + timesteps=[1000, 993, 987, 981, 975, 909, 725, 0.03], + decode_timestep=0.05, + decode_noise_scale=0.025, + image_cond_noise_scale=0.0, + guidance_scale=1.0, + guidance_rescale=0.7, + generator=torch.Generator().manual_seed(0), + output_type="latent", +).frames + +# Part 2. Upscale generated video using latent upsampler with fewer inference steps +# The available latent upsampler upscales the height/width by 2x +upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2 +upscaled_latents = pipe_upsample( + latents=latents, + adain_factor=1.0, + output_type="latent" +).frames + +# Part 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended) +video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=upscaled_width, + height=upscaled_height, + num_frames=num_frames, + denoise_strength=0.999, # Effectively, 4 inference steps out of 5 + timesteps=[1000, 909, 725, 421, 0], + latents=upscaled_latents, + decode_timestep=0.05, + decode_noise_scale=0.025, + image_cond_noise_scale=0.0, + guidance_scale=1.0, + guidance_rescale=0.7, generator=torch.Generator().manual_seed(0), output_type="pil", ).frames[0] @@ -128,6 +232,8 @@ video = [frame.resize((expected_width, expected_height)) for frame in video] export_to_video(video, "output.mp4", fps=24) ``` +
+ ## Loading Single Files Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format. diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 606d146fd13b..7f669ef50ecf 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -140,6 +140,33 @@ def retrieve_timesteps( return timesteps, num_inference_steps +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): r""" Pipeline for text-to-video generation. @@ -481,6 +508,10 @@ def prepare_latents( def guidance_scale(self): return self._guidance_scale + @property + def guidance_rescale(self): + return self._guidance_rescale + @property def do_classifier_free_guidance(self): return self._guidance_scale > 1.0 @@ -514,6 +545,7 @@ def __call__( num_inference_steps: int = 50, timesteps: List[int] = None, guidance_scale: float = 3, + guidance_rescale: float = 0.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -556,6 +588,11 @@ def __call__( of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -624,6 +661,7 @@ def __call__( ) self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale self._attention_kwargs = attention_kwargs self._interrupt = False self._current_timestep = None @@ -737,6 +775,12 @@ def __call__( noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + if self.guidance_rescale > 0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index a99294a9e22f..472488065831 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -222,6 +222,33 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): r""" Pipeline for text/image/video-to-video generation. @@ -794,6 +821,10 @@ def get_timesteps(self, sigmas, timesteps, num_inference_steps, strength): def guidance_scale(self): return self._guidance_scale + @property + def guidance_rescale(self): + return self._guidance_rescale + @property def do_classifier_free_guidance(self): return self._guidance_scale > 1.0 @@ -833,6 +864,7 @@ def __call__( num_inference_steps: int = 50, timesteps: List[int] = None, guidance_scale: float = 3, + guidance_rescale: float = 0.0, image_cond_noise_scale: float = 0.15, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -893,6 +925,11 @@ def __call__( of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -967,6 +1004,7 @@ def __call__( ) self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale self._attention_kwargs = attention_kwargs self._interrupt = False self._current_timestep = None @@ -1063,9 +1101,11 @@ def __call__( latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio - sigmas = linear_quadratic_schedule(num_inference_steps) - timesteps = sigmas * 1000 + if timesteps is None: + sigmas = linear_quadratic_schedule(num_inference_steps) + timesteps = sigmas * 1000 timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + sigmas = self.scheduler.sigmas num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) latent_sigma = None @@ -1152,6 +1192,12 @@ def __call__( noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) timestep, _ = timestep.chunk(2) + if self.guidance_rescale > 0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + denoised_latents = self.scheduler.step( -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False )[0] diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 416294955918..94bb63b4a4ba 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -159,6 +159,33 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): r""" Pipeline for image-to-video generation. @@ -542,6 +569,10 @@ def prepare_latents( def guidance_scale(self): return self._guidance_scale + @property + def guidance_rescale(self): + return self._guidance_rescale + @property def do_classifier_free_guidance(self): return self._guidance_scale > 1.0 @@ -576,6 +607,7 @@ def __call__( num_inference_steps: int = 50, timesteps: List[int] = None, guidance_scale: float = 3, + guidance_rescale: float = 0.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -620,6 +652,11 @@ def __call__( of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -688,6 +725,7 @@ def __call__( ) self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale self._attention_kwargs = attention_kwargs self._interrupt = False self._current_timestep = None @@ -811,6 +849,12 @@ def __call__( noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) timestep, _ = timestep.chunk(2) + if self.guidance_rescale > 0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + # compute the previous noisy sample x_t -> x_t-1 noise_pred = self._unpack_latents( noise_pred, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py index a90e52e71709..49cf94e25d30 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py @@ -91,6 +91,34 @@ def prepare_latents( init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) return init_latents + def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent + tensor. + + Args: + latent (`torch.Tensor`): + Input latents to normalize + reference_latents (`torch.Tensor`): + The reference latents providing style statistics. + factor (`float`): + Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result + @staticmethod # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents def _normalize_latents( @@ -160,6 +188,7 @@ def __call__( latents: Optional[torch.Tensor] = None, decode_timestep: Union[float, List[float]] = 0.0, decode_noise_scale: Optional[Union[float, List[float]]] = None, + adain_factor: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -204,7 +233,12 @@ def __call__( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) latents = latents.to(self.latent_upsampler.dtype) - latents = self.latent_upsampler(latents) + latents_upsampled = self.latent_upsampler(latents) + + if adain_factor > 0.0: + latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor) + else: + latents = latents_upsampled if output_type == "latent": latents = self._normalize_latents(