
Commit aea7483

Merge branch 'main' into properly-skip-tests-ii
2 parents 9394a19 + e7db062 commit aea7483

8 files changed: +196 −5 lines changed


examples/community/rerender_a_video.py

Lines changed: 3 additions & 0 deletions
@@ -908,6 +908,9 @@ def __call__(
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         else:
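
For context (not part of the diff): the added lines assume the community pipeline already defines an `XLA_AVAILABLE` flag and an `xm` alias at module level. A minimal sketch of that guard, assuming the module mirrors the pattern used in the core diffusers pipelines:

# Module-level guard (sketch): only import torch_xla when it is installed.
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm  # alias for the XLA runtime API

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

Calling `xm.mark_step()` once per denoising step flushes the lazily traced XLA graph, so TPU execution compiles one step at a time instead of accumulating the whole loop into a single graph.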

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 102 additions & 2 deletions
@@ -486,6 +486,9 @@ def __init__(
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448
 
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -515,6 +518,8 @@ def enable_tiling(
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
 
     def disable_tiling(self) -> None:
         r"""
@@ -606,11 +611,106 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
             return (decoded,)
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
-        raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, x.shape[2], self.tile_sample_stride_height):
+            row = []
+            for j in range(0, x.shape[3], self.tile_sample_stride_width):
+                tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                if (
+                    tile.shape[2] % self.spatial_compression_ratio != 0
+                    or tile.shape[3] % self.spatial_compression_ratio != 0
+                ):
+                    pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio
+                    pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio
+                    tile = F.pad(tile, (0, pad_w, 0, pad_h))
+                tile = self.encoder(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
+
+        if not return_dict:
+            return (encoded,)
+        return EncoderOutput(latent=encoded)
 
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = z.shape
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        decoded = torch.cat(result_rows, dim=2)
+
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)
 
     def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
         encoded = self.encode(sample, return_dict=False)[0]
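
A rough usage sketch for the new tiled path (not part of the diff; the checkpoint id and tile sizes below are only illustrative assumptions): once `enable_tiling()` is called, inputs larger than the configured tile size are routed through `tiled_encode`/`tiled_decode`, with overlapping tiles feathered together by `blend_v`/`blend_h` to hide seams.

import torch
from diffusers import AutoencoderDC

# Hypothetical checkpoint id, shown for illustration only.
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")
vae.enable_tiling(
    tile_sample_min_height=512,
    tile_sample_min_width=512,
    tile_sample_stride_height=448,
    tile_sample_stride_width=448,
)

x = torch.randn(1, 3, 1024, 1024)  # an input large enough to trigger the tiled path
with torch.no_grad():
    latent = vae.encode(x).latent      # dispatches to tiled_encode when tiling is enabled
    recon = vae.decode(latent).sample  # dispatches to tiled_decode

The stride being smaller than the tile size is what creates the overlap region that the blend helpers interpolate across.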

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 29 additions & 0 deletions
@@ -183,6 +183,35 @@ def __init__(
             pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()),
         )
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+        allow processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
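
A short usage sketch for the new convenience methods (not part of the diff; the checkpoint id, PAG layer choice, and prompt are illustrative assumptions): they simply forward to the underlying AutoencoderDC, so the memory-saving modes can be toggled directly from the pipeline.

import torch
from diffusers import SanaPAGPipeline

# Hypothetical checkpoint id and PAG layer, shown for illustration only.
pipe = SanaPAGPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    pag_applied_layers=["transformer_blocks.8"],
    torch_dtype=torch.bfloat16,
).to("cuda")

pipe.enable_vae_tiling()   # decode in tiles to cut peak VRAM on large images
pipe.enable_vae_slicing()  # decode batched latents one sample at a time

image = pipe(prompt="an astronaut riding a horse", pag_scale=2.0).images[0]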

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 29 additions & 0 deletions
@@ -218,6 +218,35 @@ def __init__(
         )
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+        allow processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
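
The same four methods are mirrored on the non-PAG pipeline. A sketch (checkpoint id assumed, not from the diff) showing slicing used for a multi-image batch:

import torch
from diffusers import SanaPipeline

# Hypothetical checkpoint id, shown for illustration only.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers", torch_dtype=torch.bfloat16
).to("cuda")

pipe.enable_vae_slicing()  # the batch of latents is decoded slice by slice

images = pipe(prompt="a watercolor skyline at dusk", num_images_per_prompt=4).images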

src/diffusers/schedulers/scheduling_ddim_inverse.py

Lines changed: 1 addition & 1 deletion
@@ -266,7 +266,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
 
         self.num_inference_steps = num_inference_steps
 
-        # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
+        # "leading" and "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
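
For reference, the two spacings named in that comment build the timestep grid differently; a small standalone illustration (not the scheduler's exact code, which also adjusts ordering for inversion):

import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10
step_ratio = num_train_timesteps // num_inference_steps

# "leading": start at 0 and step forward in multiples of the ratio
leading = np.arange(num_inference_steps) * step_ratio  # [0, 100, ..., 900]

# "trailing": step back from num_train_timesteps, then shift by -1
trailing = np.round(
    np.arange(num_train_timesteps, 0, -num_train_timesteps / num_inference_steps)
).astype(np.int64) - 1  # [999, 899, ..., 99]

print(leading, trailing)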

tests/lora/test_lora_layers_sd3.py

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ def test_sd3_img2img_lora(self):
 
         image = pipe(**inputs).images[0]
         image_slice = image[0, -3:, -3:]
-        expected_slice = np.array([0.5396, 0.5776, 0.7432, 0.5151, 0.5586, 0.7383, 0.5537, 0.5933, 0.7153])
+        expected_slice = np.array([0.5649, 0.5405, 0.5488, 0.5688, 0.5449, 0.5513, 0.5337, 0.5107, 0.5059])
 
         max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

tests/pipelines/sana/test_sana.py

Lines changed: 30 additions & 0 deletions
@@ -254,6 +254,36 @@ def test_attention_slicing_forward_pass(
             "Attention slicing should not affect the inference results",
         )
 
+    def test_vae_tiling(self, expected_diff_max: float = 0.2):
+        generator_device = "cpu"
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe.to("cpu")
+        pipe.set_progress_bar_config(disable=None)
+
+        # Without tiling
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["height"] = inputs["width"] = 128
+        output_without_tiling = pipe(**inputs)[0]
+
+        # With tiling
+        pipe.vae.enable_tiling(
+            tile_sample_min_height=96,
+            tile_sample_min_width=96,
+            tile_sample_stride_height=64,
+            tile_sample_stride_width=64,
+        )
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["height"] = inputs["width"] = 128
+        output_with_tiling = pipe(**inputs)[0]
+
+        self.assertLess(
+            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+            expected_diff_max,
+            "VAE tiling should not affect the inference results",
+        )
+
     # TODO(aryan): Create a dummy gemma model with smol vocab size
     @unittest.skip(
         "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 1 addition & 1 deletion
@@ -372,7 +372,7 @@ def test_quality(self):
             output_type="np",
         ).images
         out_slice = output[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.0376, 0.0359, 0.0015, 0.0449, 0.0479, 0.0098, 0.0083, 0.0295, 0.0295])
+        expected_slice = np.array([0.0674, 0.0623, 0.0364, 0.0632, 0.0671, 0.0430, 0.0317, 0.0493, 0.0583])
 
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-2)
