Skip to content

Commit ec918b9

Browse files
author
Pham Hong Vinh
committed
add framewise decode
1 parent 811560b commit ec918b9

File tree

1 file changed

+51
-7
lines changed

1 file changed

+51
-7
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,10 +1010,12 @@ def __init__(
10101010
# The minimal tile height and width for spatial tiling to be used
10111011
self.tile_sample_min_height = 512
10121012
self.tile_sample_min_width = 512
1013+
self.tile_sample_min_num_frames = 16
10131014

10141015
# The minimal distance between two spatial tiles
10151016
self.tile_sample_stride_height = 448
10161017
self.tile_sample_stride_width = 448
1018+
self.tile_sample_stride_num_frames = 8
10171019

10181020
def _set_gradient_checkpointing(self, module, value=False):
10191021
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
@@ -1114,6 +1116,53 @@ def encode(
11141116
if not return_dict:
11151117
return (posterior,)
11161118
return AutoencoderKLOutput(latent_dist=posterior)
1119+
1120+
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the trailing frames of `a` into the leading frames of `b`.

    The first `blend_extent` frames of `b` (frame axis is dim -3) are overwritten
    in place with a linear blend: the weight on `a` ramps down from 1 while the
    weight on `b` ramps up from 0. Returns `b`, mutated in place.
    """
    # Never blend more frames than either tensor actually provides.
    extent = min(a.shape[-3], b.shape[-3], blend_extent)
    for offset in range(extent):
        ratio = offset / extent
        b[:, :, offset, :, :] = (
            a[:, :, offset - extent, :, :] * (1 - ratio) + b[:, :, offset, :, :] * ratio
        )
    return b
1127+
1128+
def _temporal_tiled_decode(self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1129+
batch_size, num_channels, num_frames, height, width = z.shape
1130+
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
1131+
1132+
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1133+
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1134+
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
1135+
tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
1136+
blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
1137+
1138+
row = []
1139+
for i in range(0, num_frames, tile_latent_stride_num_frames):
1140+
tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
1141+
if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
1142+
decoded = self.tiled_decode(tile, temb, return_dict=True).sample
1143+
else:
1144+
print("NOT Use tile decode")
1145+
print(f"input tile: {tile.size()}")
1146+
decoded = self.decoder(tile, temb)
1147+
print(f"output tile: {decoded.size()}")
1148+
if i > 0:
1149+
decoded = decoded[:, :, :-1, :, :]
1150+
row.append(decoded)
1151+
1152+
result_row = []
1153+
for i, tile in enumerate(row):
1154+
if i > 0:
1155+
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
1156+
tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
1157+
result_row.append(tile)
1158+
else:
1159+
result_row.append(tile[:, :, :self.tile_sample_stride_num_frames + 1, :, :])
1160+
1161+
dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
1162+
1163+
if not return_dict:
1164+
return (dec,)
1165+
return DecoderOutput(sample=dec)
11171166

11181167
def _decode(
11191168
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
@@ -1125,13 +1174,8 @@ def _decode(
11251174
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
11261175
return self.tiled_decode(z, temb, return_dict=return_dict)
11271176

1128-
if self.use_framewise_decoding:
1129-
# TODO(aryan): requires investigation
1130-
raise NotImplementedError(
1131-
"Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
1132-
"quality issues caused by splitting inference across frame dimension. If you believe this "
1133-
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
1134-
)
1177+
if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
1178+
dec = self._temporal_tiled_decode(z, temb, return_dict=False)[0]
11351179
else:
11361180
dec = self.decoder(z, temb)
11371181

0 commit comments

Comments
 (0)