Commit e79162c

add sanity test tiling for ltx
1 parent 64a0849 commit e79162c

File tree

2 files changed: +35 -3 lines


src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 4 additions & 3 deletions
@@ -998,8 +998,8 @@ def __init__(
 
         # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
         # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
-        self.use_framewise_encoding = False
-        self.use_framewise_decoding = False
+        self.use_framewise_encoding = True
+        self.use_framewise_decoding = True
 
         # This can be configured based on the amount of GPU memory available.
         # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.
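
The two defaults quoted in the comment above are linked by the VAE's temporal compression: one batch of latent frames expands to (batch size x compression ratio) decoded sample frames. A minimal sketch of that arithmetic, assuming a temporal compression ratio of 8 (an assumption for illustration; the real value comes from the model config, and the variable names below are illustrative, not the module's attributes):

temporal_compression_ratio = 8   # assumed here; read from the VAE config in practice
sample_frames_batch_size = 16    # the `16` sample-frame default quoted above
latent_frames_batch_size = sample_frames_batch_size // temporal_compression_ratio
print(latent_frames_batch_size)  # 2, the `2` latent-frame default quoted above
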
@@ -1122,6 +1122,7 @@ def _decode(
         batch_size, num_channels, num_frames, height, width = z.shape
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
 
         if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
             return self._temporal_tiled_decode(z, temb, return_dict=return_dict)
@@ -1388,5 +1389,5 @@ def forward(
         z = posterior.mode()
         dec = self.decode(z, temb)
         if not return_dict:
-            return (dec,)
+            return (dec.sample,)
         return dec
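
With `use_framewise_encoding`/`use_framewise_decoding` now defaulting to True, enabling tiling is enough to get temporally tiled decoding for long latents. A minimal usage sketch, assuming the `Lightricks/LTX-Video` checkpoint layout and a toy latent shape (both assumptions for illustration, not part of this commit):

import torch

from diffusers import AutoencoderKLLTXVideo

# Checkpoint id is an assumption; any AutoencoderKLLTXVideo weights would work the same way.
vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae")
vae.enable_tiling()  # framewise/temporal tiling now applies to long videos by default

latents = torch.randn(1, 128, 3, 8, 8)  # (batch, channels, latent frames, height, width) -- toy shape
with torch.no_grad():
    video = vae.decode(latents).sample  # DecoderOutput.sample, matching the tuple fix above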

tests/models/autoencoders/test_models_autoencoder_ltx_video.py

Lines changed: 31 additions & 0 deletions
@@ -167,3 +167,34 @@ def test_outputs_equivalence(self):
     @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
     def test_forward_with_norm_groups(self):
         pass
+
+    def test_enable_disable_tiling(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict).to(torch_device)
+
+        inputs_dict.update({"return_dict": False})
+
+        torch.manual_seed(0)
+        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        torch.manual_seed(0)
+        model.enable_tiling()
+        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        self.assertLess(
+            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
+            0.5,
+            "VAE tiling should not affect the inference results",
+        )
+
+        torch.manual_seed(0)
+        model.disable_tiling()
+        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        self.assertEqual(
+            output_without_tiling.detach().cpu().numpy().all(),
+            output_without_tiling_2.detach().cpu().numpy().all(),
+            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
+        )
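
The new sanity test can be run on its own; a typical invocation (the pytest selection flag is illustrative, not part of the commit):

pytest tests/models/autoencoders/test_models_autoencoder_ltx_video.py -k test_enable_disable_tiling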
