Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents


@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down Expand Up @@ -744,7 +748,8 @@ def __call__(
prompt_attention_mask,
negative_prompt_attention_mask,
)

self._interrupt = False

# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
Expand Down Expand Up @@ -812,6 +817,8 @@ def __call__(

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

Expand Down Expand Up @@ -859,7 +866,7 @@ def __call__(
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
callback(self, step_idx, t, latents) #Not 100% sure if this will break anything. Callback documentation would need to be updated to to reflect the added input

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
Expand Down