diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 040d935f1b88..96118afc3c32 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -239,7 +239,7 @@ def _get_t5_prompt_embeds(
                 f" {max_sequence_length} tokens: {removed_text}"
             )
 
-        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+        prompt_embeds = self.text_encoder_2(text_input_ids.to(device, non_blocking=True), output_hidden_states=False)[0]
 
         dtype = self.text_encoder_2.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
@@ -284,11 +284,11 @@ def _get_clip_prompt_embeds(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer_max_length} tokens: {removed_text}"
             )
-        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+        prompt_embeds = self.text_encoder(text_input_ids.to(device, non_blocking=True), output_hidden_states=False)
 
         # Use pooled output of CLIPTextModel
         prompt_embeds = prompt_embeds.pooler_output
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device, non_blocking=True)
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
@@ -371,7 +371,7 @@ def encode_prompt(
             unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros((prompt_embeds.shape[1], 3), device=device, dtype=dtype)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
@@ -437,7 +437,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids.to(device=device, dtype=dtype, non_blocking=True)
 
     @staticmethod
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
@@ -508,7 +508,7 @@ def prepare_latents(
 
         if latents is not None:
             latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
+            return latents.to(device=device, dtype=dtype, non_blocking=True), latent_image_ids
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -710,6 +710,10 @@ def __call__(
             sigmas,
             mu=mu,
         )
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(0)
+
+
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index 937cae2e47f5..85657ed5ebff 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -201,10 +201,10 @@ def set_timesteps(
         else:
             sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
 
-        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device, non_blocking=True)
         timesteps = sigmas * self.config.num_train_timesteps
 
-        self.timesteps = timesteps.to(device=device)
+        self.timesteps = timesteps.to(device=device, non_blocking=True)
         self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
 
         self._step_index = None
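
For reference, a minimal standalone sketch (not part of the patch; names and shapes are illustrative) of the two transfer patterns this diff relies on: allocating tensors directly on the target device instead of creating them on the CPU and moving them, and using non_blocking copies, which only overlap with GPU work when the source tensor lives in pinned host memory.

# Illustrative sketch only; assumes a CUDA device is available.
import torch

device = torch.device("cuda")

# 1) Allocate directly on the target device rather than CPU-allocate then .to(device):
text_ids = torch.zeros((512, 3), device=device, dtype=torch.bfloat16)

# 2) Asynchronous host-to-device copy: non_blocking=True only overlaps with GPU
#    compute when the host tensor is in pinned (page-locked) memory.
host_tensor = torch.randn(1, 512, 4096).pin_memory()
device_tensor = host_tensor.to(device, non_blocking=True)

# The copy is enqueued on the current CUDA stream; kernels launched afterwards on
# the same stream see the completed copy, but host-side reads need a sync first.
torch.cuda.synchronize()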