
Commit 64a0849

add framewise encode, refactor tiled encode/decode
1 parent ec918b9 commit 64a0849

File tree

1 file changed: +92 -82 lines changed

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 92 additions & 82 deletions
@@ -1025,8 +1025,10 @@ def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
         tile_sample_min_width: Optional[int] = None,
+        tile_sample_min_num_frames: Optional[int] = None,
         tile_sample_stride_height: Optional[float] = None,
         tile_sample_stride_width: Optional[float] = None,
+        tile_sample_stride_num_frames: Optional[float] = None,
     ) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1048,8 +1050,10 @@ def enable_tiling(
         self.use_tiling = True
         self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

     def disable_tiling(self) -> None:
         r"""
@@ -1075,18 +1079,13 @@ def disable_slicing(self) -> None:
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = x.shape

+        if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
+            return self._temporal_tiled_encode(x)
+
         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
             return self.tiled_encode(x)

-        if self.use_framewise_encoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
-        else:
-            enc = self.encoder(x)
+        enc = self.encoder(x)

         return enc

@@ -1116,53 +1115,6 @@ def encode(
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
-
-    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
-        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
-        for x in range(blend_extent):
-            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
-                x / blend_extent
-            )
-        return b
-
-    def _temporal_tiled_decode(self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        batch_size, num_channels, num_frames, height, width = z.shape
-        num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
-
-        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
-        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
-        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
-        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
-        blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
-
-        row = []
-        for i in range(0, num_frames, tile_latent_stride_num_frames):
-            tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
-            if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
-                decoded = self.tiled_decode(tile, temb, return_dict=True).sample
-            else:
-                print("NOT Use tile decode")
-                print(f"input tile: {tile.size()}")
-                decoded = self.decoder(tile, temb)
-                print(f"output tile: {decoded.size()}")
-            if i > 0:
-                decoded = decoded[:, :, :-1, :, :]
-            row.append(decoded)
-
-        result_row = []
-        for i, tile in enumerate(row):
-            if i > 0:
-                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
-                tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
-                result_row.append(tile)
-            else:
-                result_row.append(tile[:, :, :self.tile_sample_stride_num_frames + 1, :, :])
-
-        dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
-
-        if not return_dict:
-            return (dec,)
-        return DecoderOutput(sample=dec)

     def _decode(
         self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
@@ -1171,13 +1123,13 @@ def _decode(
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio

+        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
+            return self._temporal_tiled_decode(z, temb, return_dict=return_dict)
+
         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, temb, return_dict=return_dict)

-        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
-            dec = self._temporal_tiled_decode(z, temb, return_dict=False)[0]
-        else:
-            dec = self.decoder(z, temb)
+        dec = self.decoder(z, temb)

         if not return_dict:
             return (dec,)
@@ -1232,6 +1184,14 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.
                 x / blend_extent
             )
         return b
+
+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b

     def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.
@@ -1261,17 +1221,9 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         for i in range(0, height, self.tile_sample_stride_height):
             row = []
             for j in range(0, width, self.tile_sample_stride_width):
-                if self.use_framewise_encoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.encoder(
-                        x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
-                    )
+                time = self.encoder(
+                    x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                )

                 row.append(time)
             rows.append(row)
@@ -1327,17 +1279,9 @@ def tiled_decode(
         for i in range(0, height, tile_latent_stride_height):
             row = []
             for j in range(0, width, tile_latent_stride_width):
-                if self.use_framewise_decoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.decoder(
-                        z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
-                    )
+                time = self.decoder(
+                    z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
+                )

                 row.append(time)
             rows.append(row)
@@ -1362,6 +1306,72 @@ def tiled_decode(

         return DecoderOutput(sample=dec)

+    def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        batch_size, num_channels, num_frames, height, width = x.shape
+        latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
+
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, self.tile_sample_stride_num_frames):
+            tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
+            if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
+                tile = self.tiled_encode(tile)
+            else:
+                tile = self.encoder(tile)
+            if i > 0:
+                tile = tile[:, :, 1:, :, :]
+            row.append(tile)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
+            else:
+                result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
+
+        enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
+        return enc
+
+    def _temporal_tiled_decode(self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+        num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, tile_latent_stride_num_frames):
+            tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
+            if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
+                decoded = self.tiled_decode(tile, temb, return_dict=True).sample
+            else:
+                decoded = self.decoder(tile, temb)
+            if i > 0:
+                decoded = decoded[:, :, :-1, :, :]
+            row.append(decoded)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
+                result_row.append(tile)
+            else:
+                result_row.append(tile[:, :, :self.tile_sample_stride_num_frames + 1, :, :])
+
+        dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
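
For reference, a minimal usage sketch of the tiling API this commit extends. It assumes an AutoencoderKLLTXVideo instance (the class defined in this file); the default-constructed model, the tile sizes, and the random input shape are illustrative rather than library defaults, and `use_framewise_decoding` is the internal flag the refactored `_encode`/`_decode` check before taking the temporal path.

import torch
from diffusers import AutoencoderKLLTXVideo

# Hypothetical instance with random weights; in practice you would load
# pretrained weights via AutoencoderKLLTXVideo.from_pretrained(...).
vae = AutoencoderKLLTXVideo()

# enable_tiling now also accepts frame-count limits (added in this commit).
vae.enable_tiling(
    tile_sample_min_height=512,
    tile_sample_min_width=512,
    tile_sample_min_num_frames=16,     # new argument
    tile_sample_stride_height=448,
    tile_sample_stride_width=448,
    tile_sample_stride_num_frames=8,   # new argument
)

# As committed, both _encode and _decode gate the temporal path on this internal
# flag plus the input having more frames than the configured minimum.
vae.use_framewise_decoding = True

video = torch.randn(1, 3, 25, 256, 384)  # (batch, channels, frames, height, width)
latents = vae.encode(video).latent_dist.sample()

Because 25 frames exceeds tile_sample_min_num_frames is false here only when the minimum is raised, the call above goes through _temporal_tiled_encode, which encodes overlapping frame chunks and stitches them back together.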
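The stitching relies on blend_t, added above, which linearly cross-fades the first `blend_extent` frames of each tile into the trailing frames of the previous one. A standalone toy run (blend_t copied from the diff and wrapped as a free function purely for illustration) makes the ramp visible:

import torch

def blend_t(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Linear cross-fade over the frame dimension (dim -3) of two overlapping tiles.
    blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
    for x in range(blend_extent):
        b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
            x / blend_extent
        )
    return b

a = torch.zeros(1, 1, 8, 1, 1)  # previous tile, all zeros
b = torch.ones(1, 1, 8, 1, 1)   # next tile, all ones
out = blend_t(a, b.clone(), blend_extent=4)
print(out[0, 0, :, 0, 0])       # first 4 frames ramp 0.00, 0.25, 0.50, 0.75, then stay at 1.0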

0 commit comments
