
Commit 83b55ba

Merge branch 'main' into flux_ptxla_trillium
2 parents fdc592e + edb8c1b commit 83b55ba

Showing 31 changed files with 278 additions and 23 deletions.

examples/community/rerender_a_video.py

Lines changed: 3 additions & 0 deletions
@@ -908,6 +908,9 @@ def __call__(
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         else:
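`XLA_AVAILABLE` and `xm` are not defined in this hunk; they presumably come from the guarded torch_xla import used by other diffusers pipelines. A minimal sketch of that assumed pattern:

# Assumed module-level guard (not shown in this diff); mirrors other diffusers pipelines.
try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False

Calling `xm.mark_step()` once per denoising iteration cuts the lazily traced XLA graph at a step boundary so it can be compiled and executed, instead of the whole sampling loop accumulating into a single graph.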

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 102 additions & 2 deletions
@@ -486,6 +486,9 @@ def __init__(
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448
 
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -515,6 +518,8 @@ def enable_tiling(
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
 
     def disable_tiling(self) -> None:
         r"""
@@ -606,11 +611,106 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
             return (decoded,)
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
-        raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, x.shape[2], self.tile_sample_stride_height):
+            row = []
+            for j in range(0, x.shape[3], self.tile_sample_stride_width):
+                tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                if (
+                    tile.shape[2] % self.spatial_compression_ratio != 0
+                    or tile.shape[3] % self.spatial_compression_ratio != 0
+                ):
+                    pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio
+                    pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio
+                    tile = F.pad(tile, (0, pad_w, 0, pad_h))
+                tile = self.encoder(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
+
+        if not return_dict:
+            return (encoded,)
+        return EncoderOutput(latent=encoded)
 
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = z.shape
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        decoded = torch.cat(result_rows, dim=2)
+
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)
 
     def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
         encoded = self.encode(sample, return_dict=False)[0]
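The overlap between neighbouring tiles is `tile_sample_min - tile_sample_stride` pixels (its latent-space equivalent on the encode side), and `blend_v`/`blend_h` linearly cross-fade that overlap so the stitched result has no hard seams. A rough usage sketch for the new tiled paths; the checkpoint id and input size are illustrative assumptions, not taken from this commit:

import torch
from diffusers import AutoencoderDC

# Any AutoencoderDC checkpoint works the same way; this id is only an example.
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
vae.enable_tiling()  # optionally override tile_sample_min_* / tile_sample_stride_* here

with torch.no_grad():
    image = torch.randn(1, 3, 1024, 1024)
    latent = vae.tiled_encode(image, return_dict=False)[0]      # encoded tile by tile, overlaps blended
    restored = vae.tiled_decode(latent, return_dict=False)[0]   # decoded tile by tile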

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py

Lines changed: 2 additions & 2 deletions
@@ -404,9 +404,9 @@ def encode_prompt(
             negative_prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                 `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
-            negative_prompt_2 (`str` or `List[str]`, *optional*):
+            negative_prompt_3 (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
-                `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
+                `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
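As the corrected docstring now states, each of SD3's three text encoders can take its own negative prompt, with `negative_prompt` as the shared fallback. A hedged call sketch; the checkpoint ids and image URL are placeholders, not part of this commit:

import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

# Placeholder checkpoints for illustration only.
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

control_image = load_image("https://example.com/canny_edges.png")  # placeholder URL
image = pipe(
    prompt="a lighthouse on a cliff at sunset",
    negative_prompt="blurry",             # fallback for all three encoders
    negative_prompt_2="low quality",      # sent to tokenizer_2 / text_encoder_2
    negative_prompt_3="oversaturated",    # sent to tokenizer_3 / text_encoder_3
    control_image=control_image,
).images[0]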

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py

Lines changed: 2 additions & 2 deletions
@@ -410,9 +410,9 @@ def encode_prompt(
             negative_prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                 `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
-            negative_prompt_2 (`str` or `List[str]`, *optional*):
+            negative_prompt_3 (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
-                `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
+                `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 22 additions & 2 deletions
@@ -665,7 +665,16 @@ def __call__(
                 instead.
             prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
-                will be used instead
+                will be used instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+                not greater than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+            true_cfg_scale (`float`, *optional*, defaults to 1.0):
+                When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -709,6 +718,14 @@ def __call__(
                 Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                 IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                 provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -773,7 +790,10 @@ def __call__(
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
-        do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
+        has_neg_prompt = negative_prompt is not None or (
+            negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+        )
+        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
        (
            prompt_embeds,
            pooled_prompt_embeds,
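With `has_neg_prompt`, true classifier-free guidance now also engages when only precomputed negative embeddings are supplied, not just a text `negative_prompt`. A hedged usage sketch; the model id is illustrative, and the transformer runs twice per step when true CFG is on:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

image = pipe(
    prompt="a tiny astronaut hatching from an egg on the moon",
    negative_prompt="washed out, low contrast",
    true_cfg_scale=4.0,   # > 1.0 plus a negative prompt (or both negative embedding tensors) enables true CFG
    guidance_scale=3.5,   # Flux's distilled guidance input, separate from true_cfg_scale
    num_inference_steps=28,
).images[0]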

src/diffusers/pipelines/ltx/pipeline_ltx.py

Lines changed: 1 addition & 1 deletion
@@ -769,7 +769,7 @@ def __call__(
             if not self.vae.config.timestep_conditioning:
                 timestep = None
             else:
-                noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+                noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
                 if not isinstance(decode_timestep, list):
                     decode_timestep = [decode_timestep] * batch_size
                 if decode_noise_scale is None:
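This one-line swap matters because `randn_tensor` (from `diffusers.utils.torch_utils`) accepts a list of per-sample generators and a CPU generator targeting an accelerator device, both of which `torch.randn` rejects. A small sketch of the assumed behaviour:

import torch
from diffusers.utils.torch_utils import randn_tensor

shape = (2, 128, 8, 32, 32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One CPU generator per sample keeps the noise reproducible regardless of batch size.
generators = [torch.Generator("cpu").manual_seed(i) for i in range(shape[0])]

# torch.randn would raise here (it takes a single generator whose device must match `device`);
# randn_tensor samples on the generators' device and then moves the result to `device`.
noise = randn_tensor(shape, generator=generators, device=device, dtype=torch.float32)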

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 29 additions & 0 deletions
@@ -183,6 +183,35 @@ def __init__(
             pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()),
         )
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],

src/diffusers/pipelines/pag/pipeline_pag_sd_3.py

Lines changed: 2 additions & 2 deletions
@@ -375,9 +375,9 @@ def encode_prompt(
             negative_prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                 `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
-            negative_prompt_2 (`str` or `List[str]`, *optional*):
+            negative_prompt_3 (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
-                `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
+                `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.

src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py

Lines changed: 2 additions & 2 deletions
@@ -391,9 +391,9 @@ def encode_prompt(
             negative_prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                 `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
-            negative_prompt_2 (`str` or `List[str]`, *optional*):
+            negative_prompt_3 (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
-                `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
+                `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 29 additions & 0 deletions
@@ -218,6 +218,35 @@ def __init__(
         )
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
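A brief usage sketch for the new memory helpers; the checkpoint id is an assumption for illustration, not part of this commit:

import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# Both switches only affect the VAE and trade a little speed for lower peak decode memory.
pipe.enable_vae_slicing()  # decode the batch one sample at a time
pipe.enable_vae_tiling()   # decode each sample tile by tile (the AutoencoderDC tiling added above)

image = pipe("a cyberpunk cat with a neon sign", num_inference_steps=20).images[0]

pipe.disable_vae_tiling()
pipe.disable_vae_slicing()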
