
Commit ab2cb59

Merge branch 'huggingface:main' into main
2 parents 514235a + 36acdd7 commit ab2cb59

21 files changed: +230 lines, -6 lines

examples/community/rerender_a_video.py

Lines changed: 3 additions & 0 deletions

@@ -908,6 +908,9 @@ def __call__(
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         else:
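Note: the `XLA_AVAILABLE` flag and `xm` alias used by the added lines come from the pipeline's module-level import guard, which is not shown in this hunk. A minimal sketch of that guard as it is commonly written in diffusers pipelines (illustrative, not the exact lines of this commit):

# Guarded torch_xla import; this is a sketch of the usual diffusers pattern,
# not a verbatim copy of the lines added by this commit.
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

Calling `xm.mark_step()` once per denoising iteration asks the XLA runtime to cut and dispatch the lazily traced graph, which keeps per-step graphs small when running on TPU.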

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 102 additions & 2 deletions

@@ -486,6 +486,9 @@ def __init__(
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448
 
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,

@@ -515,6 +518,8 @@ def enable_tiling(
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
 
     def disable_tiling(self) -> None:
         r"""

@@ -606,11 +611,106 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
             return (decoded,)
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
-        raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, x.shape[2], self.tile_sample_stride_height):
+            row = []
+            for j in range(0, x.shape[3], self.tile_sample_stride_width):
+                tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                if (
+                    tile.shape[2] % self.spatial_compression_ratio != 0
+                    or tile.shape[3] % self.spatial_compression_ratio != 0
+                ):
+                    pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio
+                    pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio
+                    tile = F.pad(tile, (0, pad_w, 0, pad_h))
+                tile = self.encoder(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
+
+        if not return_dict:
+            return (encoded,)
+        return EncoderOutput(latent=encoded)
 
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = z.shape
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        decoded = torch.cat(result_rows, dim=2)
+
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)
 
     def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
         encoded = self.encode(sample, return_dict=False)[0]
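With the `NotImplementedError` stubs replaced, tiling on `AutoencoderDC` can be exercised end to end through its public API. A minimal usage sketch, assuming that `encode`/`decode` route through the tiled paths once tiling is enabled and the input exceeds the tile size; the checkpoint id and input size below are illustrative assumptions, not part of this diff:

import torch
from diffusers import AutoencoderDC

# Assumed checkpoint id for illustration; any DC-AE / AutoencoderDC checkpoint works.
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")

# Tile and stride sizes are given in pixel (sample) space; the latent-space
# equivalents are derived by dividing by spatial_compression_ratio, as in the diff above.
vae.enable_tiling(
    tile_sample_min_height=512,
    tile_sample_min_width=512,
    tile_sample_stride_height=448,
    tile_sample_stride_width=448,
)

# The input is split into overlapping tiles, each tile is encoded/decoded on its
# own, and the overlaps are blended with blend_v / blend_h to hide seams.
image = torch.randn(1, 3, 1024, 1024)
latent = vae.encode(image).latent
reconstruction = vae.decode(latent).sample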

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 29 additions & 0 deletions

@@ -184,6 +184,35 @@ def __init__(
             pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()),
         )
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 29 additions & 0 deletions

@@ -218,6 +218,35 @@ def __init__(
         )
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
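These helpers forward directly to the underlying `AutoencoderDC`; `SanaPAGPipeline` above gains the identical four methods. A minimal usage sketch with `SanaPipeline` (the checkpoint id, dtype, and prompt are illustrative assumptions, not taken from this commit):

import torch
from diffusers import SanaPipeline

# Assumed checkpoint id; substitute any Sana checkpoint.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# Both toggles simply delegate to the pipeline's AutoencoderDC.
pipe.enable_vae_tiling()   # decode large images tile by tile
pipe.enable_vae_slicing()  # decode batched latents one slice at a time

image = pipe("a tiny astronaut hatching from an egg on the moon").images[0]

# Revert to single-pass decoding when memory is not a constraint.
pipe.disable_vae_tiling()
pipe.disable_vae_slicing()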

src/diffusers/schedulers/scheduling_ddim_inverse.py

Lines changed: 1 addition & 1 deletion

@@ -266,7 +266,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
 
         self.num_inference_steps = num_inference_steps
 
-        # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
+        # "leading" and "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
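For context, "leading" and "trailing" name two ways of laying discrete inference timesteps over the training schedule (Table 2 of the referenced paper). An illustrative sketch with assumed example values, not the scheduler's exact code:

import numpy as np

num_train_timesteps = 1000  # assumed example value
num_inference_steps = 10    # assumed example value

# "leading": start at 0 and advance by the integer ratio; the schedule never
# reaches the final training timestep (999).
step_ratio = num_train_timesteps // num_inference_steps
leading = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
print(leading)   # [  0 100 200 300 400 500 600 700 800 900]

# "trailing": step back from the last training timestep and shift by one, so the
# schedule ends exactly at 999 (shown ascending, matching the inversion order).
trailing = np.round(np.arange(num_train_timesteps, 0, -num_train_timesteps / num_inference_steps))
trailing = trailing.astype(np.int64)[::-1] - 1
print(trailing)  # [ 99 199 299 399 499 599 699 799 899 999]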

tests/models/autoencoders/test_models_vq.py

Lines changed: 2 additions & 0 deletions

@@ -65,9 +65,11 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skip("Test not supported.")
     def test_forward_signature(self):
         pass
 
+    @unittest.skip("Test not supported.")
     def test_training(self):
         pass

tests/models/unets/test_models_unet_1d.py

Lines changed: 6 additions & 0 deletions

@@ -51,9 +51,11 @@ def input_shape(self):
     def output_shape(self):
         return (4, 14, 16)
 
+    @unittest.skip("Test not supported.")
     def test_ema_training(self):
         pass
 
+    @unittest.skip("Test not supported.")
     def test_training(self):
         pass

@@ -126,6 +128,7 @@ def test_output_pretrained(self):
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
 
+    @unittest.skip("Test not supported.")
     def test_forward_with_norm_groups(self):
         # Not implemented yet for this UNet
         pass

@@ -205,9 +208,11 @@ def test_output(self):
         expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
 
+    @unittest.skip("Test not supported.")
     def test_ema_training(self):
         pass
 
+    @unittest.skip("Test not supported.")
     def test_training(self):
         pass

@@ -265,6 +270,7 @@ def test_output_pretrained(self):
         # fmt: on
         self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
 
+    @unittest.skip("Test not supported.")
     def test_forward_with_norm_groups(self):
         # Not implemented yet for this UNet
         pass

tests/models/unets/test_models_unet_2d.py

Lines changed: 1 addition & 0 deletions

@@ -383,6 +383,7 @@ def test_output_pretrained_ve_large(self):
 
         self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
 
+    @unittest.skip("Test not supported.")
     def test_forward_with_norm_groups(self):
         # not required for this model
         pass

tests/models/unets/test_models_unet_controlnetxs.py

Lines changed: 1 addition & 0 deletions

@@ -320,6 +320,7 @@ def test_time_embedding_mixing(self):
 
         assert output.shape == output_mix_time.shape
 
+    @unittest.skip("Test not supported.")
     def test_forward_with_norm_groups(self):
         # UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
         pass

tests/pipelines/sana/test_sana.py

Lines changed: 30 additions & 0 deletions

@@ -254,6 +254,36 @@ def test_attention_slicing_forward_pass(
             "Attention slicing should not affect the inference results",
         )
 
+    def test_vae_tiling(self, expected_diff_max: float = 0.2):
+        generator_device = "cpu"
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe.to("cpu")
+        pipe.set_progress_bar_config(disable=None)
+
+        # Without tiling
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["height"] = inputs["width"] = 128
+        output_without_tiling = pipe(**inputs)[0]
+
+        # With tiling
+        pipe.vae.enable_tiling(
+            tile_sample_min_height=96,
+            tile_sample_min_width=96,
+            tile_sample_stride_height=64,
+            tile_sample_stride_width=64,
+        )
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["height"] = inputs["width"] = 128
+        output_with_tiling = pipe(**inputs)[0]
+
+        self.assertLess(
+            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+            expected_diff_max,
+            "VAE tiling should not affect the inference results",
+        )
+
     # TODO(aryan): Create a dummy gemma model with smol vocab size
     @unittest.skip(
         "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
