
Commit 353728a
up
1 parent 7c2151f

5 files changed: +141 -85 lines

scripts/convert_ltx_to_diffusers.py

Lines changed: 2 additions & 2 deletions
@@ -350,15 +350,15 @@ def get_args():
 
     if args.version == "0.9.5":
         scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
-    else:
+    else:
         scheduler = FlowMatchEulerDiscreteScheduler(
             use_dynamic_shifting=True,
             base_shift=0.95,
             max_shift=2.05,
             base_image_seq_len=1024,
             max_image_seq_len=4096,
             shift_terminal=0.1,
-        )
+        )
 
     pipe = LTXPipeline(
         scheduler=scheduler,

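For context on the first file: the script keys the scheduler config off the checkpoint version, with `0.9.5` disabling dynamic shifting and every other version using the shifted config. A minimal standalone sketch of that branch (the version string is a stand-in for `args.version`; the scheduler kwargs are copied from the diff above):

    from diffusers import FlowMatchEulerDiscreteScheduler

    version = "0.9.5"  # stand-in for args.version from the conversion CLI
    if version == "0.9.5":
        scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
    else:
        # Shifted flow-matching schedule used for the other LTX checkpoints.
        scheduler = FlowMatchEulerDiscreteScheduler(
            use_dynamic_shifting=True,
            base_shift=0.95,
            max_shift=2.05,
            base_image_seq_len=1024,
            max_image_seq_len=4096,
            shift_terminal=0.1,
        )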
src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 10 additions & 2 deletions
@@ -1144,8 +1144,16 @@ def __init__(
         self.register_buffer("latents_mean", latents_mean, persistent=True)
         self.register_buffer("latents_std", latents_std, persistent=True)
 
-        self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) if spatial_compression_ratio is None else spatial_compression_ratio
-        self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) if temporal_compression_ratio is None else temporal_compression_ratio
+        self.spatial_compression_ratio = (
+            patch_size * 2 ** sum(spatio_temporal_scaling)
+            if spatial_compression_ratio is None
+            else spatial_compression_ratio
+        )
+        self.temporal_compression_ratio = (
+            patch_size_t * 2 ** sum(spatio_temporal_scaling)
+            if temporal_compression_ratio is None
+            else temporal_compression_ratio
+        )
 
         # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
         # to perform decoding of a single video latent at a time.

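The autoencoder change is formatting only; the fallback math is unchanged: when a ratio is not pinned in the config, it is derived from the patch size and the number of enabled spatio-temporal downscaling stages. A quick worked example, using illustrative values rather than ones read from this diff:

    patch_size, patch_size_t = 4, 1                      # assumed example config, not from this diff
    spatio_temporal_scaling = (True, True, True, False)  # three downscaling stages enabled

    spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)      # 4 * 2**3 = 32
    temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)   # 1 * 2**3 = 8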
src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 28 additions & 14 deletions
@@ -115,8 +115,16 @@ def __init__(
         self.theta = theta
         self._causal_rope_fix = _causal_rope_fix
 
-
-    def _prepare_video_coords(self, batch_size: int, num_frames: int, height: int, width: int, rope_interpolation_scale: Tuple[torch.Tensor, float, float], device: torch.device) -> torch.Tensor:
+    def _prepare_video_coords(
+        self,
+        batch_size: int,
+        num_frames: int,
+        height: int,
+        width: int,
+        rope_interpolation_scale: Tuple[torch.Tensor, float, float],
+        frame_rate: float,
+        device: torch.device,
+    ) -> torch.Tensor:
         # Always compute rope in fp32
         grid_h = torch.arange(height, dtype=torch.float32, device=device)
         grid_w = torch.arange(width, dtype=torch.float32, device=device)
@@ -132,9 +140,7 @@ def _prepare_video_coords(
             grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
         else:
             if not self._causal_rope_fix:
-                grid[:, 0:1] = (
-                    grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
-                )
+                grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
             else:
                 grid[:, 0:1] = (
                     ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
@@ -145,9 +151,8 @@ def _prepare_video_coords(
             grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width
 
         grid = grid.flatten(2, 4).transpose(1, 2)
-
+
         return grid
-
 
     def forward(
         self,
@@ -162,14 +167,22 @@ def forward(
         batch_size = hidden_states.size(0)
 
         if video_coords is None:
-            grid = self._prepare_video_coords(batch_size, num_frames, height, width, rope_interpolation_scale=rope_interpolation_scale, device=hidden_states.device)
+            grid = self._prepare_video_coords(
+                batch_size,
+                num_frames,
+                height,
+                width,
+                rope_interpolation_scale=rope_interpolation_scale,
+                frame_rate=frame_rate,
+                device=hidden_states.device,
+            )
         else:
             grid = torch.stack(
                 [
-                    video_coords[:, 0] / self.base_num_frames,
-                    video_coords[:, 1] / self.base_height,
-                    video_coords[:, 2] / self.base_width
-                ],
+                    video_coords[:, 0] / self.base_num_frames,
+                    video_coords[:, 1] / self.base_height,
+                    video_coords[:, 2] / self.base_width,
+                ],
                 dim=-1,
             )
 
@@ -432,8 +445,9 @@ def forward(
             msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0."
             deprecate("rope_interpolation_scale", "0.34.0", msg)
 
-
-        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale, video_coords)
+        image_rotary_emb = self.rope(
+            hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale, video_coords
+        )
 
         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:

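The transformer change is mostly reformatting, but one substantive piece rides along: `_prepare_video_coords` now takes `frame_rate` explicitly and `forward` passes it through, so the `_causal_rope_fix` branch can evaluate `1 / frame_rate` without reaching into the caller's scope. A small standalone sketch of that branch's frame-coordinate mapping (the helper name is mine, and the later normalization by `base_num_frames` sits outside the hunk shown, so it is omitted):

    import torch

    def causal_frame_coords(frame_idx: torch.Tensor, scale_t: float, frame_rate: float) -> torch.Tensor:
        # Mirrors the `_causal_rope_fix` branch: shift frame indices back by one,
        # rescale, offset by the duration of one frame, and clamp at zero.
        return ((frame_idx - 1) * scale_t + 1 / frame_rate).clamp(min=0)

    coords = causal_frame_coords(torch.arange(5, dtype=torch.float32), scale_t=1.0, frame_rate=25.0)
    # tensor([0.0000, 0.0400, 1.0400, 2.0400, 3.0400])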