src/diffusers/models/autoencoders/autoencoder_kl_ltx.py (58 changes: 51 additions & 7 deletions)

@@ -1010,10 +1010,12 @@ def __init__(
        # The minimal tile height and width for spatial tiling to be used
        self.tile_sample_min_height = 512
        self.tile_sample_min_width = 512
        self.tile_sample_min_num_frames = 16

        # The minimal distance between two spatial tiles
        self.tile_sample_stride_height = 448
        self.tile_sample_stride_width = 448
        self.tile_sample_stride_num_frames = 8

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):

@@ -1114,6 +1116,53 @@ def encode(
        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)
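For reference, here is how the new temporal defaults map into latent space (a quick worked check; the 8x temporal compression ratio is LTX's and is not shown in this hunk, so treat it as an assumption):

```python
# Worked numbers for the new defaults (illustrative; assumes LTX's
# temporal_compression_ratio of 8).
temporal_compression_ratio = 8
tile_sample_min_num_frames = 16    # pixel-space frames per temporal tile
tile_sample_stride_num_frames = 8  # pixel-space stride between tiles

tile_latent_min_num_frames = tile_sample_min_num_frames // temporal_compression_ratio        # 2
tile_latent_stride_num_frames = tile_sample_stride_num_frames // temporal_compression_ratio  # 1
blend_num_frames = tile_sample_min_num_frames - tile_sample_stride_num_frames                # 8

# Consecutive decoded chunks therefore overlap by 8 pixel-space frames,
# which is the crossfade window used by blend_t below.
```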

    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:

Contributor: Would move this down a few methods to where blend_h and blend_v are located.
        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
        for x in range(blend_extent):
            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
                x / blend_extent
            )
        return b
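To make the crossfade concrete, here is the same blend applied to constant tiles (a standalone sketch; the shapes are made up for illustration). The weight on `b` ramps linearly from 0 to 1 across the overlap, so the first blended frame is dominated by the tail of `a`:

```python
import torch

def blend_t(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Same logic as the method above, as a free function for the demo.
    blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
    for x in range(blend_extent):
        b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
            x / blend_extent
        )
    return b

a = torch.zeros(1, 3, 8, 4, 4)  # previous tile: all zeros
b = torch.ones(1, 3, 8, 4, 4)   # current tile: all ones
out = blend_t(a, b, blend_extent=4)
print(out[0, 0, :4, 0, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500])
```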

    def _temporal_tiled_decode(
        self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
    ) -> Union[DecoderOutput, torch.Tensor]:

Contributor: Let's remove the debug statements from this method and move it below tiled_decode. I think we would also need to implement tiled_encode. Happy to help with the changes if needed 🤗 (A rough sketch of such an encode path follows this method below.)

        batch_size, num_channels, num_frames, height, width = z.shape
        num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
        blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames

        row = []
        for i in range(0, num_frames, tile_latent_stride_num_frames):
            tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
            if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
                decoded = self.tiled_decode(tile, temb, return_dict=True).sample
            else:
                print("NOT Use tile decode")
                print(f"input tile: {tile.size()}")
                decoded = self.decoder(tile, temb)
                print(f"output tile: {decoded.size()}")
            if i > 0:
                decoded = decoded[:, :, :-1, :, :]
            row.append(decoded)

        result_row = []
        for i, tile in enumerate(row):
            if i > 0:
                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
                tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
                result_row.append(tile)
            else:
                result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])

        dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
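On the tiled_encode point above: a rough sketch of what a temporal encode counterpart could look like, assuming it mirrors the decode path (overlap tiles in pixel space, trim the duplicated leading latent frame of every tile after the first, then blend in latent space). The name `_temporal_tiled_encode` and the exact trimming are assumptions, not part of this diff:

```python
def _temporal_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
    batch_size, num_channels, num_frames, height, width = x.shape
    latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1

    tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
    tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
    # Encoded tiles overlap, so the blend window here is in latent frames.
    blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames

    row = []
    for i in range(0, num_frames, self.tile_sample_stride_num_frames):
        tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
        tile = self.encoder(tile)
        if i > 0:
            # Drop the duplicated leading latent frame of overlapping tiles.
            tile = tile[:, :, 1:, :, :]
        row.append(tile)

    result_row = []
    for i, tile in enumerate(row):
        if i > 0:
            tile = self.blend_t(row[i - 1], tile, blend_num_frames)
            result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
        else:
            result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])

    return torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
```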

    def _decode(
        self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True

@@ -1125,13 +1174,8 @@ def _decode(
         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, temb, return_dict=return_dict)

-        if self.use_framewise_decoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
+        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
+            dec = self._temporal_tiled_decode(z, temb, return_dict=False)[0]
         else:
             dec = self.decoder(z, temb)