Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _get_t5_prompt_embeds(
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
prompt_embeds = self.text_encoder_2(text_input_ids.to(device, non_blocking=True), output_hidden_states=False)[0]

dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
Expand Down Expand Up @@ -284,11 +284,11 @@ def _get_clip_prompt_embeds(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
prompt_embeds = self.text_encoder(text_input_ids.to(device, non_blocking=True), output_hidden_states=False)

# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device, non_blocking=True)

# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
Expand Down Expand Up @@ -371,7 +371,7 @@ def encode_prompt(
unscale_lora_layers(self.text_encoder_2, lora_scale)

dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = torch.zeros((prompt_embeds.shape[1], 3), device=device, dtype=dtype)

return prompt_embeds, pooled_prompt_embeds, text_ids

Expand Down Expand Up @@ -437,7 +437,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)

return latent_image_ids.to(device=device, dtype=dtype)
return latent_image_ids.to(device=device, dtype=dtype, non_blocking=True)

@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
Expand Down Expand Up @@ -508,7 +508,7 @@ def prepare_latents(

if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
return latents.to(device=device, dtype=dtype, non_blocking=True), latent_image_ids

if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
Expand Down Expand Up @@ -710,6 +710,10 @@ def __call__(
sigmas,
mu=mu,
)
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(0)


num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,10 @@ def set_timesteps(
else:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device, non_blocking=True)
timesteps = sigmas * self.config.num_train_timesteps

self.timesteps = timesteps.to(device=device)
self.timesteps = timesteps.to(device=device, non_blocking=True)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

self._step_index = None
Expand Down