@@ -1025,8 +1025,10 @@ def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
         tile_sample_min_width: Optional[int] = None,
+        tile_sample_min_num_frames: Optional[int] = None,
         tile_sample_stride_height: Optional[float] = None,
         tile_sample_stride_width: Optional[float] = None,
+        tile_sample_stride_num_frames: Optional[float] = None,
     ) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1048,8 +1050,10 @@ def enable_tiling(
         self.use_tiling = True
         self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

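The two new arguments mirror the existing spatial tiling knobs, but along the frame axis. A minimal usage sketch (the checkpoint id, dtype, and the direct `use_framewise_decoding` toggle are illustrative assumptions, not part of this diff):

```python
import torch
from diffusers import AutoencoderKLLTXVideo

# Assumed checkpoint id; any diffusers-format LTX-Video VAE behaves the same.
vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32
)

# Tile size and stride along the frame axis, in pixel-space frames.
# Arguments left as None keep the model's current defaults.
vae.enable_tiling(
    tile_sample_min_num_frames=16,
    tile_sample_stride_num_frames=8,
)

# Assumption: the temporal path is additionally gated on this flag; the diff
# adds no dedicated helper for it, so it is toggled directly here.
vae.use_framewise_decoding = True
```
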
     def disable_tiling(self) -> None:
         r"""
@@ -1075,18 +1079,13 @@ def disable_slicing(self) -> None:
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = x.shape

+        if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames:
+            return self._temporal_tiled_encode(x)
+
         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
             return self.tiled_encode(x)

-        if self.use_framewise_encoding:
-            # TODO(aryan): requires investigation
-            raise NotImplementedError(
-                "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                "quality issues caused by splitting inference across frame dimension. If you believe this "
-                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-            )
-        else:
-            enc = self.encoder(x)
+        enc = self.encoder(x)

         return enc

@@ -1116,53 +1115,6 @@ def encode(
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
-
-    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
-        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
-        for x in range(blend_extent):
-            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
-                x / blend_extent
-            )
-        return b
-
-    def _temporal_tiled_decode(self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        batch_size, num_channels, num_frames, height, width = z.shape
-        num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
-
-        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
-        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
-        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
-        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
-        blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
-
-        row = []
-        for i in range(0, num_frames, tile_latent_stride_num_frames):
-            tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
-            if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
-                decoded = self.tiled_decode(tile, temb, return_dict=True).sample
-            else:
-                print("NOT Use tile decode")
-                print(f"input tile: {tile.size()}")
-                decoded = self.decoder(tile, temb)
-                print(f"output tile: {decoded.size()}")
-            if i > 0:
-                decoded = decoded[:, :, :-1, :, :]
-            row.append(decoded)
-
-        result_row = []
-        for i, tile in enumerate(row):
-            if i > 0:
-                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
-                tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
-                result_row.append(tile)
-            else:
-                result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
-
-        dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
-
-        if not return_dict:
-            return (dec,)
-        return DecoderOutput(sample=dec)

     def _decode(
         self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
@@ -1171,13 +1123,13 @@ def _decode(
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
         tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
+            return self._temporal_tiled_decode(z, temb, return_dict=return_dict)
+
         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, temb, return_dict=return_dict)

-        if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
-            dec = self._temporal_tiled_decode(z, temb, return_dict=False)[0]
-        else:
-            dec = self.decoder(z, temb)
+        dec = self.decoder(z, temb)

         if not return_dict:
             return (dec,)
@@ -1232,6 +1184,14 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.
                 x / blend_extent
             )
         return b
+
+    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+                x / blend_extent
+            )
+        return b
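`blend_t` is the temporal counterpart of `blend_v`/`blend_h`: it cross-fades the overlap along the frame axis (dim `-3`) with a linear ramp. A tiny standalone check of that ramp (a re-implementation for illustration, not an import from this diff):

```python
import torch

# Previous tile decodes to all ones, current tile to all zeros;
# blend a 2-frame overlap at the start of the current tile.
a = torch.ones(1, 1, 4, 2, 2)
b = torch.zeros(1, 1, 4, 2, 2)
blend_extent = min(a.shape[-3], b.shape[-3], 2)

for x in range(blend_extent):
    # Frame x of b becomes a weighted mix; the weight on a decays linearly.
    b[:, :, x] = a[:, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, x] * (x / blend_extent)

print(b[0, 0, :, 0, 0])  # tensor([1.0000, 0.5000, 0.0000, 0.0000])
```
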

     def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.
@@ -1261,17 +1221,9 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         for i in range(0, height, self.tile_sample_stride_height):
             row = []
             for j in range(0, width, self.tile_sample_stride_width):
-                if self.use_framewise_encoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.encoder(
-                        x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
-                    )
+                time = self.encoder(
+                    x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                )

                 row.append(time)
             rows.append(row)
@@ -1327,17 +1279,9 @@ def tiled_decode(
         for i in range(0, height, tile_latent_stride_height):
             row = []
             for j in range(0, width, tile_latent_stride_width):
-                if self.use_framewise_decoding:
-                    # TODO(aryan): requires investigation
-                    raise NotImplementedError(
-                        "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
-                        "quality issues caused by splitting inference across frame dimension. If you believe this "
-                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
-                    )
-                else:
-                    time = self.decoder(
-                        z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
-                    )
+                time = self.decoder(
+                    z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
+                )

                 row.append(time)
             rows.append(row)
@@ -1362,6 +1306,72 @@ def tiled_decode(

         return DecoderOutput(sample=dec)

+    def _temporal_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = x.shape
+        latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
+
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, self.tile_sample_stride_num_frames):
+            tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
+            if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
+                tile = self.tiled_encode(tile)
+            else:
+                tile = self.encoder(tile)
+            if i > 0:
+                tile = tile[:, :, 1:, :, :]
+            row.append(tile)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
+            else:
+                result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
+
+        enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
+        return enc
+
+    def _temporal_tiled_decode(self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+        num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+        tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+        blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
+
+        row = []
+        for i in range(0, num_frames, tile_latent_stride_num_frames):
+            tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
+            if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
+                decoded = self.tiled_decode(tile, temb, return_dict=True).sample
+            else:
+                decoded = self.decoder(tile, temb)
+            if i > 0:
+                decoded = decoded[:, :, :-1, :, :]
+            row.append(decoded)
+
+        result_row = []
+        for i, tile in enumerate(row):
+            if i > 0:
+                tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+                tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
+                result_row.append(tile)
+            else:
+                result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
+
+        dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
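To sanity-check the bookkeeping in `_temporal_tiled_decode`: each tile grabs one extra latent frame (`i : i + tile_latent_min_num_frames + 1`), every tile after the first drops its last decoded frame, overlaps are cross-faded with `blend_t`, and the concatenation is trimmed to `(num_frames - 1) * temporal_compression_ratio + 1` samples. A worked example with assumed values (16/8 sample-frame tiles and 8x temporal compression; these are illustrative, not read from the model config):

```python
# Assumed illustrative values.
temporal_compression_ratio = 8
tile_sample_min_num_frames = 16    # tile size, in pixel-space frames
tile_sample_stride_num_frames = 8  # tile stride, in pixel-space frames

num_latent_frames = 5
num_sample_frames = (num_latent_frames - 1) * temporal_compression_ratio + 1  # 33

tile_latent_min = tile_sample_min_num_frames // temporal_compression_ratio        # 2
tile_latent_stride = tile_sample_stride_num_frames // temporal_compression_ratio  # 1

# Latent frame window each decode tile covers (note the +1 overlap frame).
for i in range(0, num_latent_frames, tile_latent_stride):
    end = min(i + tile_latent_min + 1, num_latent_frames)
    print(f"tile starting at latent frame {i}: covers [{i}, {end})")
print(f"output trimmed to {num_sample_frames} sample frames")
```
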
     def forward(
         self,
         sample: torch.Tensor,