From e19e04542fe0e98408e30a759e1741d0a2f26ecf Mon Sep 17 00:00:00 2001 From: hhsparthipan Date: Sun, 9 Feb 2025 15:37:39 +0800 Subject: [PATCH] Add support for ControlNet in Flux controlnet img2img --- ...pipeline_flux_controlnet_image_to_image.py | 364 +++++++++++++----- 1 file changed, 263 insertions(+), 101 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index d8aefc3942e9..734a2fdf7af9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -11,9 +11,16 @@ ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import ( + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +) from ...models.autoencoders import AutoencoderKL -from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.controlnets.controlnet_flux import ( + FluxControlNetModel, + FluxMultiControlNetModel, +) from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( @@ -97,7 +104,9 @@ def calculate_shift( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" + encoder_output: torch.Tensor, + generator: Optional[torch.Generator] = None, + sample_mode: str = "sample", ): if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) @@ -142,9 +151,13 @@ def retrieve_timesteps( second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -154,7 +167,9 @@ def retrieve_timesteps( timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -169,7 +184,9 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): +class FluxControlNetImg2ImgPipeline( + DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin +): r""" The Flux controlnet pipeline for image-to-image generation. 
@@ -210,7 +227,10 @@ def __init__( tokenizer_2: T5TokenizerFast, transformer: FluxTransformer2DModel, controlnet: Union[ - FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel + FluxControlNetModel, + List[FluxControlNetModel], + Tuple[FluxControlNetModel], + FluxMultiControlNetModel, ], ): super().__init__() @@ -227,12 +247,20 @@ def __init__( scheduler=scheduler, controlnet=controlnet, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) + if getattr(self, "vae", None) + else 8 + ) # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2 + ) self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + self.tokenizer.model_max_length + if hasattr(self, "tokenizer") and self.tokenizer is not None + else 77 ) self.default_sample_size = 128 @@ -264,16 +292,24 @@ def _get_t5_prompt_embeds( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer_2( + prompt, padding="longest", return_tensors="pt" + ).input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer_2.batch_decode( + untruncated_ids[:, self.tokenizer_max_length - 1 : -1] + ) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + prompt_embeds = self.text_encoder_2( + text_input_ids.to(device), output_hidden_states=False + )[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -282,7 +318,9 @@ def _get_t5_prompt_embeds( # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) return prompt_embeds @@ -312,14 +350,22 @@ def _get_clip_prompt_embeds( ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + 
text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer_max_length - 1 : -1] + ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + prompt_embeds = self.text_encoder( + text_input_ids.to(device), output_hidden_states=False + ) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output @@ -406,7 +452,11 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + dtype = ( + self.text_encoder.dtype + if self.text_encoder is not None + else self.transformer.dtype + ) text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -415,14 +465,20 @@ def encode_prompt( def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + retrieve_latents( + self.vae.encode(image[i : i + 1]), generator=generator[i] + ) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + image_latents = retrieve_latents( + self.vae.encode(image), generator=generator + ) - image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + image_latents = ( + image_latents - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor return image_latents @@ -451,15 +507,21 @@ def check_inputs( max_sequence_length=None, ): if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + raise ValueError( + f"The value of strength should in [0.0, 1.0] but is {strength}" + ) - if height % self.vae_scale_factor * 2 != 0 or width % self.vae_scale_factor * 2 != 0: + if ( + height % self.vae_scale_factor * 2 != 0 + or width % self.vae_scale_factor * 2 != 0 + ): logger.warning( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -479,10 +541,18 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + elif prompt_2 is not None and ( + not isinstance(prompt_2, str) and not isinstance(prompt_2, list) + ): + raise ValueError( + f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}" + ) if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( @@ -490,16 +560,24 @@ def check_inputs( ) if max_sequence_length is not None and max_sequence_length > 512: - raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + raise ValueError( + f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}" + ) @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + latent_image_ids[..., 1] = ( + latent_image_ids[..., 1] + torch.arange(height)[:, None] + ) + latent_image_ids[..., 2] = ( + latent_image_ids[..., 2] + torch.arange(width)[None, :] + ) - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + latent_image_id_height, latent_image_id_width, latent_image_id_channels = ( + latent_image_ids.shape + ) latent_image_ids = latent_image_ids.reshape( latent_image_id_height * latent_image_id_width, latent_image_id_channels @@ -510,9 +588,13 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + latents = latents.reshape( + batch_size, (height // 2) * (width // 2), num_channels_latents * 4 + ) return latents @@ -558,18 +640,28 @@ def prepare_latents( height = 2 * (int(height) // (self.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + latent_image_ids = self._prepare_latent_image_ids( + batch_size, height // 2, width // 2, device, dtype + ) if latents is not None: return latents.to(device=device, dtype=dtype), latent_image_ids image = image.to(device=device, dtype=dtype) image_latents = self._encode_vae_image(image=image, generator=generator) - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + if ( + batch_size > 
image_latents.shape[0] + and batch_size % image_latents.shape[0] == 0 + ): # expand init_latents for batch_size additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + image_latents = torch.cat( + [image_latents] * additional_image_per_prompt, dim=0 + ) + elif ( + batch_size > image_latents.shape[0] + and batch_size % image_latents.shape[0] != 0 + ): raise ValueError( f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." ) @@ -578,7 +670,9 @@ def prepare_latents( noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self.scheduler.scale_noise(image_latents, timestep, noise) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents( + latents, batch_size, num_channels_latents, height, width + ) return latents, latent_image_ids # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image @@ -729,12 +823,24 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + if not isinstance(control_guidance_start, list) and isinstance( + control_guidance_end, list + ): + control_guidance_start = len(control_guidance_end) * [ + control_guidance_start + ] + elif not isinstance(control_guidance_end, list) and isinstance( + control_guidance_start, list + ): control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 + elif not isinstance(control_guidance_start, list) and not isinstance( + control_guidance_end, list + ): + mult = ( + len(self.controlnet.nets) + if isinstance(self.controlnet, FluxMultiControlNetModel) + else 1 + ) control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], @@ -767,7 +873,9 @@ def __call__( dtype = self.transformer.dtype lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + self.joint_attention_kwargs.get("scale", None) + if self.joint_attention_kwargs is not None + else None ) ( prompt_embeds, @@ -800,18 +908,26 @@ def __call__( dtype=self.vae.dtype, ) height, width = control_image.shape[-2:] - - control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) - control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - height_control_image, width_control_image = control_image.shape[2:] - control_image = self._pack_latents( - control_image, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = ( + False if self.controlnet.input_hint_block is None else True ) + if 
self.controlnet.input_hint_block is None: + control_image = retrieve_latents( + self.vae.encode(control_image), generator=generator + ) + control_image = ( + control_image - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) if control_mode is not None: control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) @@ -820,7 +936,11 @@ def __call__( elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] - for control_image_ in control_image: + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = ( + False if self.controlnet.nets[0].input_hint_block is None else True + ) + for i, control_image_ in enumerate(control_image): control_image_ = self.prepare_image( image=control_image_, width=width, @@ -831,20 +951,24 @@ def __call__( dtype=self.vae.dtype, ) height, width = control_image_.shape[-2:] - - control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) - control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - height_control_image, width_control_image = control_image_.shape[2:] - control_image_ = self._pack_latents( - control_image_, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) - - control_images.append(control_image_) + if self.controlnet.nets[0].input_hint_block is None: + control_image_ = retrieve_latents( + self.vae.encode(control_image_), generator=generator + ) + control_image_ = ( + control_image_ - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + control_images.append(control_image_) control_image = control_images @@ -858,8 +982,14 @@ def __call__( control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) control_mode = control_mode.reshape([-1, 1]) - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas - image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + sigmas = ( + np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if sigmas is None + else sigmas + ) + image_seq_len = (int(height) // self.vae_scale_factor // 2) * ( + int(width) // self.vae_scale_factor // 2 + ) mu = calculate_shift( image_seq_len, self.scheduler.config.get("base_image_seq_len", 256), @@ -874,7 +1004,9 @@ def __call__( sigmas=sigmas, mu=mu, ) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, strength, device + ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) latents, latent_image_ids = self.prepare_latents( @@ -896,9 +1028,13 @@ def __call__( 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] - controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) + controlnet_keep.append( + keeps[0] 
if isinstance(self.controlnet, FluxControlNetModel) else keeps + ) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -913,36 +1049,53 @@ def __call__( else: use_guidance = self.controlnet.config.guidance_embeds - guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None - guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + guidance = ( + torch.tensor([guidance_scale], device=device) + if use_guidance + else None + ) + guidance = ( + guidance.expand(latents.shape[0]) if guidance is not None else None + ) if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + cond_scale = [ + c * s + for c, s in zip( + controlnet_conditioning_scale, controlnet_keep[i] + ) + ] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] - controlnet_block_samples, controlnet_single_block_samples = self.controlnet( - hidden_states=latents, - controlnet_cond=control_image, - controlnet_mode=control_mode, - conditioning_scale=cond_scale, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, + controlnet_block_samples, controlnet_single_block_samples = ( + self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=cond_scale, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) ) guidance = ( - torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None + torch.tensor([guidance_scale], device=device) + if self.transformer.config.guidance_embeds + else None + ) + guidance = ( + guidance.expand(latents.shape[0]) if guidance is not None else None ) - guidance = guidance.expand(latents.shape[0]) if guidance is not None else None noise_pred = self.transformer( hidden_states=latents, @@ -956,10 +1109,13 @@ def __call__( img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, )[0] latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False + )[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): @@ -974,7 +1130,9 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if XLA_AVAILABLE: @@ -983,8 +1141,12 @@ def __call__( if output_type == "latent": image = 
latents else: - latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + latents = self._unpack_latents( + latents, height, width, self.vae_scale_factor + ) + latents = ( + latents / self.vae.config.scaling_factor + ) + self.vae.config.shift_factor image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type)
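
A minimal usage sketch of the pipeline after this change (the model IDs and image URLs below are illustrative placeholders; any Flux base model plus an XLabs- or InstantX-style Flux ControlNet checkpoint should work). With an XLabs-style checkpoint, which exposes an input_hint_block, the prepared control image is passed to the ControlNet without VAE encoding or packing and controlnet_blocks_repeat=True is forwarded to the transformer; with an InstantX-style checkpoint the control image is VAE-encoded and packed as before.

    import torch
    from diffusers import FluxControlNetImg2ImgPipeline, FluxControlNetModel
    from diffusers.utils import load_image

    # Placeholder checkpoints -- swap in the Flux base model and ControlNet you use.
    controlnet = FluxControlNetModel.from_pretrained(
        "InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16
    )
    pipe = FluxControlNetImg2ImgPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        controlnet=controlnet,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Placeholder images: an init image to transform and its Canny edge map.
    init_image = load_image("https://example.com/init.png")
    control_image = load_image("https://example.com/canny.png")

    image = pipe(
        prompt="a futuristic cityscape at dusk",
        image=init_image,
        control_image=control_image,
        controlnet_conditioning_scale=0.6,
        strength=0.7,
        num_inference_steps=28,
        guidance_scale=3.5,
    ).images[0]
    image.save("flux_controlnet_img2img.png")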