Commit 322d03d

make style

1 parent 3de88be

File tree: 3 files changed, +52 -33 lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -470,11 +470,11 @@
         "LDMTextToImagePipeline",
         "LEditsPPPipelineStableDiffusion",
         "LEditsPPPipelineStableDiffusionXL",
+        "LTXConditionInfinitePipeline",
         "LTXConditionPipeline",
         "LTXImageToVideoPipeline",
         "LTXLatentUpsamplePipeline",
         "LTXPipeline",
-        "LTXConditionInfinitePipeline",
         "Lumina2Pipeline",
         "Lumina2Text2ImgPipeline",
         "LuminaPipeline",
@@ -1056,7 +1056,6 @@
         EasyAnimatePipeline,
         FluxControlImg2ImgPipeline,
         FluxControlInpaintPipeline,
-        LTXConditionInfinitePipeline,
         FluxControlNetImg2ImgPipeline,
         FluxControlNetInpaintPipeline,
         FluxControlNetPipeline,
@@ -1109,6 +1108,7 @@
         LDMTextToImagePipeline,
         LEditsPPPipelineStableDiffusion,
         LEditsPPPipelineStableDiffusionXL,
+        LTXConditionInfinitePipeline,
         LTXConditionPipeline,
         LTXImageToVideoPipeline,
         LTXLatentUpsamplePipeline,
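The three hunks above only move LTXConditionInfinitePipeline into alphabetical order within the export list and the TYPE_CHECKING imports; the public import path is unchanged. A minimal smoke test, assuming a diffusers build that includes this commit:

    # The re-sort is cosmetic: the name is still exported from the package root.
    from diffusers import LTXConditionInfinitePipeline, LTXConditionPipeline

    print(LTXConditionInfinitePipeline.__name__)  # LTXConditionInfinitePipeline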

src/diffusers/pipelines/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -672,7 +672,13 @@
             LEditsPPPipelineStableDiffusion,
             LEditsPPPipelineStableDiffusionXL,
         )
-        from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline, LTXConditionInfinitePipeline
+        from .ltx import (
+            LTXConditionInfinitePipeline,
+            LTXConditionPipeline,
+            LTXImageToVideoPipeline,
+            LTXLatentUpsamplePipeline,
+            LTXPipeline,
+        )
         from .lumina import LuminaPipeline, LuminaText2ImgPipeline
         from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
         from .marigold import (

src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py

Lines changed: 43 additions & 30 deletions
@@ -19,7 +19,6 @@
 from transformers import T5EncoderModel, T5TokenizerFast

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...image_processor import PipelineImageInput
 from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
 from ...models.autoencoders import AutoencoderKLLTXVideo
 from ...models.transformers import LTXVideoTransformer3DModel
@@ -559,7 +558,7 @@ def _extract_spatial_tile(self, latents, v_start, v_end, h_start, h_end):
         """Extract spatial tiles from all inputs for a given spatial region."""
         tile_latents = latents[:, :, :, v_start:v_end, h_start:h_end]
         return tile_latents
-        
+
     def _select_latents(self, latents: torch.Tensor, start_index: int, end_index: int) -> torch.Tensor:
         num_frames = latents.shape[2]
         start_idx = num_frames + start_index if start_index < 0 else start_index
@@ -570,11 +569,9 @@ def _select_latents(self, latents: torch.Tensor, start_index: int, end_index: in
         start_idx = min(start_idx, end_idx)
         latents = latents[:, :, start_idx : end_idx + 1, :, :].clone()
         return latents
-        
+
     @staticmethod
-    def _create_spatial_weights(
-        latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap
-    ):
+    def _create_spatial_weights(latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap):
         """Create blending weights for spatial tiles."""
         tile_weights = torch.ones_like(latents)

@@ -658,7 +655,7 @@ def prepare_latents(
         latent_width = width // self.vae_spatial_compression_ratio
         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        
+
         if latents is not None:
             if latents.shape != shape:
                 raise ValueError(
@@ -678,11 +675,7 @@ def prepare_latents(
             device=device,
         )
         video_ids = self._scale_video_ids(
-            video_ids,
-            self.vae_spatial_compression_ratio,
-            self.vae_temporal_compression_ratio,
-            0,
-            device
+            video_ids, self.vae_spatial_compression_ratio, self.vae_temporal_compression_ratio, 0, device
         )

         return latents, video_ids
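For intuition on the latent shape assembled in prepare_latents, a toy computation; the compression ratios and channel count below are stand-ins, and the frame formula is the usual causal-VAE convention, assumed here rather than shown in this hunk:

    # Hypothetical values for illustration only; the real ones are attributes
    # of the loaded VAE and transformer.
    vae_spatial_compression_ratio = 32
    vae_temporal_compression_ratio = 8
    batch_size, num_channels_latents = 1, 128
    width, height, num_frames = 768, 512, 121

    latent_height = height // vae_spatial_compression_ratio  # 16
    latent_width = width // vae_spatial_compression_ratio    # 24
    num_latent_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1  # 16 (assumed formula)
    shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
    print(shape)  # (1, 128, 16, 16, 24)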
@@ -857,7 +850,9 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
         if horizontal_tiles > 1 or vertical_tiles > 1:
-            raise ValueError("Setting `horizontal_tiles` or `vertical_tiles` to a value greater than 0 is not supported yet.")
+            raise ValueError(
+                "Setting `horizontal_tiles` or `vertical_tiles` to a value greater than 0 is not supported yet."
+            )

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -967,11 +962,14 @@ def __call__(
             first_tile_out_latents = None

             for index_temporal_tile, (start_index, end_index) in enumerate(
-                zip(range(0, temporal_range_max, temporal_range_step),
-                    range(temporal_tile_size, temporal_range_max, temporal_range_step)
+                zip(
+                    range(0, temporal_range_max, temporal_range_step),
+                    range(temporal_tile_size, temporal_range_max, temporal_range_step),
                 )
             ):
-                latent_chunk = self._select_latents(tile_latents, start_index, min(end_index - 1, tile_latents.shape[2] - 1))
+                latent_chunk = self._select_latents(
+                    tile_latents, start_index, min(end_index - 1, tile_latents.shape[2] - 1)
+                )
                 latent_tile_num_frames = latent_chunk.shape[2]

                 if start_index > 0:
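As a standalone illustration of the (start_index, end_index) windows the zip above yields, here is a sketch with made-up sizes; the relationship temporal_range_step = temporal_tile_size - temporal_overlap is an assumption for illustration, not something this hunk shows:

    # Hypothetical values; the real ones come from the pipeline's arguments.
    temporal_tile_size = 8
    temporal_overlap = 2
    temporal_range_step = temporal_tile_size - temporal_overlap  # assumed
    temporal_range_max = 25

    windows = list(
        zip(
            range(0, temporal_range_max, temporal_range_step),
            range(temporal_tile_size, temporal_range_max, temporal_range_step),
        )
    )
    print(windows)  # [(0, 8), (6, 14), (12, 20)] -- overlapping temporal tiles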
@@ -981,12 +979,14 @@ def __call__(
                     total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames

                     conditioning_mask = torch.zeros(
-                        (batch_size, total_latent_num_frames), dtype=torch.float32, device=device,
+                        (batch_size, total_latent_num_frames),
+                        dtype=torch.float32,
+                        device=device,
                     )
                     conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
                 else:
                     total_latent_num_frames = latent_tile_num_frames
-                
+
                 latent_chunk = self._pack_latents(
                     latent_chunk,
                     self.transformer_spatial_patch_size,
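The mask built above flags which latent frames are frozen conditioning (the overlap carried over from the previous temporal tile) versus frames to denoise. A toy reproduction with stand-in sizes:

    import torch

    # 1.0 = conditioning frames reused from the previous tile, 0.0 = new frames.
    batch_size, last_latent_tile_num_frames, latent_tile_num_frames = 1, 3, 5
    total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames

    conditioning_mask = torch.zeros((batch_size, total_latent_num_frames), dtype=torch.float32)
    conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
    print(conditioning_mask)  # tensor([[1., 1., 1., 0., 0., 0., 0., 0.]])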
@@ -1002,29 +1002,31 @@ def __call__(
                     patch_size=self.transformer_spatial_patch_size,
                     device=device,
                 )
-                
+
                 if start_index > 0:
                     conditioning_mask = conditioning_mask.gather(1, video_ids[:, 0])
                     conditioning_mask_model_input = (
                         torch.cat([conditioning_mask, conditioning_mask])
                         if self.do_classifier_free_guidance
                         else conditioning_mask
                     )
-                
+
                 video_ids = self._scale_video_ids(
                     video_ids,
                     scale_factor=self.vae_spatial_compression_ratio,
                     scale_factor_t=self.vae_temporal_compression_ratio,
                     frame_index=0,
-                    device=device
+                    device=device,
                 )
                 video_ids = video_ids.float()
                 video_ids[:, 0] = video_ids[:, 0] * (1.0 / frame_rate)
                 if self.do_classifier_free_guidance:
                     video_ids = torch.cat([video_ids, video_ids], dim=0)

                 # Set timesteps
-                inner_timesteps, inner_num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+                inner_timesteps, inner_num_inference_steps = retrieve_timesteps(
+                    self.scheduler, num_inference_steps, device, timesteps
+                )
                 sigmas = self.scheduler.sigmas
                 num_warmup_steps = max(len(inner_timesteps) - inner_num_inference_steps * self.scheduler.order, 0)
                 self._num_timesteps = len(inner_timesteps)
@@ -1035,7 +1037,9 @@ def __call__(
                         continue

                     self._current_timestep = t
-                    latent_model_input = torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
+                    latent_model_input = (
+                        torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
+                    )
                     latent_model_input = latent_model_input.to(prompt_embeds.dtype)
                     timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
                     if start_index > 0:
@@ -1054,7 +1058,9 @@ def __call__(

                     if self.do_classifier_free_guidance:
                         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (
+                            noise_pred_text - noise_pred_uncond
+                        )
                         timestep, _ = timestep.chunk(2)

                     if self.guidance_rescale > 0:
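The reformatted expression is the standard classifier-free guidance update: start from the unconditional prediction and extrapolate toward the text-conditioned one. A scalar toy example with stand-in values:

    import torch

    guidance_scale = 5.0
    noise_pred_uncond = torch.tensor([0.10])  # stand-in values
    noise_pred_text = torch.tensor([0.30])

    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    print(noise_pred)  # tensor([1.1000]) = 0.10 + 5.0 * (0.30 - 0.10)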
@@ -1082,7 +1088,9 @@ def __call__(
                         prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                     # call the callback, if provided
-                    if i == len(inner_timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    if i == len(inner_timesteps) - 1 or (
+                        (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                    ):
                         progress_bar.update()

                     if XLA_AVAILABLE:
@@ -1096,13 +1104,15 @@ def __call__(
                     self.transformer_spatial_patch_size,
                     self.transformer_temporal_patch_size,
                 )
-                
+
                 if start_index == 0:
                     first_tile_out_latents = latent_chunk.clone()
                 else:
                     # We drop the first latent frame as it's a reinterpreted 8-frame latent understood as 1-frame latent
-                    latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames + 1:, :, :]
-                    latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(latent_chunk, first_tile_out_latents, adain_factor)
+                    latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames + 1 :, :, :]
+                    latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(
+                        latent_chunk, first_tile_out_latents, adain_factor
+                    )

                     alpha = torch.linspace(1, 0, temporal_overlap + 1, device=latent_chunk.device)[1:-1]
                     alpha = alpha.view(1, 1, -1, 1, 1)
@@ -1111,14 +1121,17 @@ def __call__(
                     t_minus_one = temporal_overlap - 1
                     parts = [
                         tile_out_latents[:, :, :-t_minus_one],
-                        alpha * tile_out_latents[:, :, -t_minus_one:] + (1 - alpha) * latent_chunk[:, :, :t_minus_one],
+                        alpha * tile_out_latents[:, :, -t_minus_one:]
+                        + (1 - alpha) * latent_chunk[:, :, :t_minus_one],
                         latent_chunk[:, :, t_minus_one:],
                     ]
                     latent_chunk = torch.cat(parts, dim=2)

                 tile_out_latents = latent_chunk.clone()

-                tile_weights = self._create_spatial_weights(tile_out_latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap)
+                tile_weights = self._create_spatial_weights(
+                    tile_out_latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap
+                )
                 final_latents[:, :, :, v_start:v_end, h_start:h_end] += latent_chunk * tile_weights
                 weights[:, :, :, v_start:v_end, h_start:h_end] += tile_weights
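The two hunks above implement a linear cross-fade between consecutive temporal tiles: the trailing temporal_overlap - 1 frames of the previous chunk are mixed with the leading frames of the new chunk using linearly decaying weights. A standalone sketch with toy tensors (real latents are 5D with many channels and larger spatial dims):

    import torch

    temporal_overlap = 3
    prev = torch.ones(1, 1, 6, 1, 1)    # stand-in for tile_out_latents
    new = torch.zeros(1, 1, 6, 1, 1)    # stand-in for latent_chunk

    # linspace(1, 0, 4)[1:-1] -> tensor([0.6667, 0.3333]): weights for prev.
    alpha = torch.linspace(1, 0, temporal_overlap + 1)[1:-1].view(1, 1, -1, 1, 1)

    t_minus_one = temporal_overlap - 1
    parts = [
        prev[:, :, :-t_minus_one],                                                  # untouched prev frames
        alpha * prev[:, :, -t_minus_one:] + (1 - alpha) * new[:, :, :t_minus_one],  # cross-fade
        new[:, :, t_minus_one:],                                                    # untouched new frames
    ]
    blended = torch.cat(parts, dim=2)
    print(blended.flatten())  # 1, 1, 1, 1, 0.6667, 0.3333, 0, 0, 0, 0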
