@@ -1010,10 +1010,12 @@ def __init__(
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 512
         self.tile_sample_min_width = 512
+        self.tile_sample_min_num_frames = 16
 
         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448
+        self.tile_sample_stride_num_frames = 8
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
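For orientation: with these defaults, consecutive temporal tiles start `tile_sample_stride_num_frames` apart but each covers up to `tile_sample_min_num_frames` (plus one causal boundary frame), so neighbours overlap by `16 - 8 = 8` frames, which is exactly the region the new `blend_t` helper cross-fades. A minimal sketch of that geometry, using only the default values above (the 49-frame input length is illustrative):

```python
# Temporal tile layout implied by the new defaults (plain-Python sketch).
tile_min, tile_stride = 16, 8
num_frames = 49  # illustrative input length; LTX videos use 8k + 1 frames

# Each tile spans [i, i + tile_min + 1); consecutive tiles overlap by
# tile_min - tile_stride = 8 frames, the extent later blended by blend_t.
tiles = [(i, min(i + tile_min + 1, num_frames)) for i in range(0, num_frames, tile_stride)]
print(tiles)  # [(0, 17), (8, 25), (16, 33), (24, 41), (32, 49), (40, 49), (48, 49)]
```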
@@ -1023,8 +1025,10 @@ def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
         tile_sample_min_width: Optional[int] = None,
+        tile_sample_min_num_frames: Optional[int] = None,
         tile_sample_stride_height: Optional[float] = None,
         tile_sample_stride_width: Optional[float] = None,
+        tile_sample_stride_num_frames: Optional[float] = None,
     ) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1046,8 +1050,10 @@ def enable_tiling(
         self.use_tiling = True
         self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
 
     def disable_tiling(self) -> None:
         r"""
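With the setters in place, callers can opt in per dimension. A hedged usage sketch; the checkpoint id and dtype are illustrative, and note that the temporal path is additionally gated by the `use_framewise_encoding`/`use_framewise_decoding` flags checked in `_encode`/`_decode` below:

```python
import torch
from diffusers import AutoencoderKLLTXVideo

# Illustrative checkpoint; any repo exposing this VAE class under "vae" works.
vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.bfloat16
)

# Every argument is optional; omitted ones keep the __init__ defaults
# (512/512/16 minimums, 448/448/8 strides).
vae.enable_tiling(
    tile_sample_min_num_frames=16,
    tile_sample_stride_num_frames=8,
)
```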
@@ -1073,18 +1079,13 @@ def disable_slicing(self) -> None:
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = x.shape
 
+        if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames:
+            return self._temporal_tiled_encode(x)
+
         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
             return self.tiled_encode(x)
 
-        if self.use_framewise_encoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
-        else:
-            enc = self.encoder(x)
+        enc = self.encoder(x)
 
         return enc
 
@@ -1121,19 +1122,15 @@ def _decode(
         batch_size, num_channels, num_frames, height, width = z.shape
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+
+        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
+            return self._temporal_tiled_decode(z, temb, return_dict=return_dict)
 
         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, temb, return_dict=return_dict)
 
-        if self.use_framewise_decoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
-        else:
-            dec = self.decoder(z, temb)
+        dec = self.decoder(z, temb)
 
         if not return_dict:
             return (dec,)
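The dispatch order now matters: the temporal split is checked before the spatial one, so long videos are first chunked over frames, and each chunk may then be spatially tiled from inside `_temporal_tiled_decode`. A condensed, self-contained restatement of the decode-side routing; the 32x spatial and 8x temporal compression ratios are assumptions for illustration (the real values come from the model config):

```python
def decode_path(num_frames: int, height: int, width: int,
                use_framewise_decoding: bool = True, use_tiling: bool = True) -> str:
    # Latent-space thresholds under the assumed ratios and __init__ defaults.
    tile_latent_min_height = 512 // 32    # 16
    tile_latent_min_width = 512 // 32     # 16
    tile_latent_min_num_frames = 16 // 8  # 2

    if use_framewise_decoding and num_frames > tile_latent_min_num_frames:
        return "_temporal_tiled_decode"
    if use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
        return "tiled_decode"
    return "decoder"

print(decode_path(num_frames=13, height=16, width=16))  # _temporal_tiled_decode
print(decode_path(num_frames=2, height=32, width=32))   # tiled_decode
print(decode_path(num_frames=2, height=16, width=16))   # decoder
```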
@@ -1189,6 +1186,14 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.
             )
         return b
 
+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
+
     def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.
 
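`blend_t` mirrors the existing `blend_v`/`blend_h` helpers, but along the frame axis (dim `-3`): frame `x` of tile `b` is overwritten with a linear cross-fade that starts fully weighted toward the tail of tile `a` and ramps toward `b`. A quick endpoint check of a free-standing copy of the same logic:

```python
import torch

def blend_t(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Identical logic to the method above, extracted for a standalone test.
    blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
    for x in range(blend_extent):
        b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[
            :, :, x, :, :
        ] * (x / blend_extent)
    return b

a = torch.zeros(1, 1, 4, 1, 1)  # left tile: all zeros
b = torch.ones(1, 1, 4, 1, 1)   # right tile: all ones
out = blend_t(a, b.clone(), blend_extent=4)
print(out.flatten().tolist())   # [0.0, 0.25, 0.5, 0.75] -> ramps from a into b
```

Like its siblings, `blend_t` mutates `b` in place and returns it, hence the `.clone()` in the test.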
@@ -1217,17 +1222,9 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         for i in range(0, height, self.tile_sample_stride_height):
             row = []
             for j in range(0, width, self.tile_sample_stride_width):
-                if self.use_framewise_encoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.encoder(
-                        x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
-                    )
+                time = self.encoder(
+                    x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                )
 
                 row.append(time)
             rows.append(row)
@@ -1283,17 +1280,7 @@ def tiled_decode(
         for i in range(0, height, tile_latent_stride_height):
             row = []
             for j in range(0, width, tile_latent_stride_width):
-                if self.use_framewise_decoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.decoder(
-                        z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
-                    )
+                time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)
 
                 row.append(time)
             rows.append(row)
@@ -1318,6 +1305,74 @@ def tiled_decode(
 
         return DecoderOutput(sample=dec)
 
+    def _temporal_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = x.shape
+        latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
+
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, self.tile_sample_stride_num_frames):
+            tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
+            if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
+                tile = self.tiled_encode(tile)
+            else:
+                tile = self.encoder(tile)
+            if i > 0:
+                tile = tile[:, :, 1:, :, :]
+            row.append(tile)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
+            else:
+                result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
+
+        enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
+        return enc
+
+    def _temporal_tiled_decode(
+        self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+        num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, tile_latent_stride_num_frames):
+            tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
+            if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
+                decoded = self.tiled_decode(tile, temb, return_dict=True).sample
+            else:
+                decoded = self.decoder(tile, temb)
+            if i > 0:
+                decoded = decoded[:, :, :-1, :, :]
+            row.append(decoded)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
+                result_row.append(tile)
+            else:
+                result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
+
+        dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
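The `+ 1` slices and final truncations implement LTX's causal frame convention: `n` latent frames correspond to `(n - 1) * temporal_compression_ratio + 1` sample frames, so each tile carries one extra boundary frame that its neighbour drops after blending. A worked sketch of the round-trip accounting, again assuming a temporal ratio of 8:

```python
# Causal frame accounting used by both temporal tiling paths
# (temporal_compression_ratio == 8 is an assumption for illustration).
ratio = 8
num_sample_frames = 97                                    # an 8k + 1 input
latent_num_frames = (num_sample_frames - 1) // ratio + 1  # 13, as in _temporal_tiled_encode
round_trip = (latent_num_frames - 1) * ratio + 1          # 97, as in _temporal_tiled_decode
assert round_trip == num_sample_frames

# Encode tiles blend in latent space (16//8 - 8//8 = 1 frame); decode tiles
# blend in sample space (16 - 8 = 8 frames), matching blend_num_frames above.
print(latent_num_frames, round_trip, 16 // ratio - 8 // ratio, 16 - 8)  # 13 97 1 8
```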
@@ -1334,5 +1389,5 @@ def forward(
         z = posterior.mode()
         dec = self.decode(z, temb)
         if not return_dict:
-            return (dec,)
+            return (dec.sample,)
         return dec
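The one-line `forward` fix is needed because `self.decode(z, temb)` is called without `return_dict=False`, so `dec` is a `DecoderOutput` dataclass and the old tuple wrapped the dataclass rather than the decoded tensor. A minimal restatement with a stand-in for the real class:

```python
from dataclasses import dataclass
import torch

@dataclass
class DecoderOutput:  # stand-in for diffusers' DecoderOutput
    sample: torch.Tensor

dec = DecoderOutput(sample=torch.zeros(1, 3, 9, 64, 64))

assert not isinstance((dec,)[0], torch.Tensor)     # before: tuple of dataclass
assert isinstance((dec.sample,)[0], torch.Tensor)  # after: tuple of tensor
```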