Commit c5e6d62

run make style
1 parent e79162c commit c5e6d62

File tree

1 file changed: +7, -7 lines

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 7 additions & 7 deletions
@@ -1015,7 +1015,7 @@ def __init__(
         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448
-        self.tile_sample_stride_num_frames = 8
+        self.tile_sample_stride_num_frames = 8
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
@@ -1185,7 +1185,7 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.
                 x / blend_extent
             )
         return b
-
+
     def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
         for x in range(blend_extent):
@@ -1280,9 +1280,7 @@ def tiled_decode(
         for i in range(0, height, tile_latent_stride_height):
             row = []
             for j in range(0, width, tile_latent_stride_width):
-                time = self.decoder(
-                    z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
-                )
+                time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)
 
                 row.append(time)
             rows.append(row)
@@ -1337,7 +1335,9 @@ def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
         enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
         return enc
 
-    def _temporal_tiled_decode(self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+    def _temporal_tiled_decode(
+        self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.Tensor]:
         batch_size, num_channels, num_frames, height, width = z.shape
         num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
 
@@ -1365,7 +1365,7 @@ def _temporal_tiled_decode(self, z: torch.Tensor, temb: Optional[torch.Tensor],
                 tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
                 result_row.append(tile)
             else:
-                result_row.append(tile[:, :, :self.tile_sample_stride_num_frames + 1, :, :])
+                result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
 
         dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
 