Commit 1141cdd

up
1 parent 1cfd2ee commit 1141cdd

File tree

1 file changed (+2, -22 lines)

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 2 additions & 22 deletions
@@ -22,7 +22,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
@@ -102,7 +102,6 @@ def __init__(
         patch_size: int = 1,
         patch_size_t: int = 1,
         theta: float = 10000.0,
-        _causal_rope_fix: bool = False,
     ) -> None:
         super().__init__()

@@ -113,7 +112,6 @@ def __init__(
         self.patch_size = patch_size
         self.patch_size_t = patch_size_t
         self.theta = theta
-        self._causal_rope_fix = _causal_rope_fix

     def _prepare_video_coords(
         self,
@@ -133,22 +131,10 @@ def _prepare_video_coords(
         grid = torch.stack(grid, dim=0)
         grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

-        if isinstance(rope_interpolation_scale, tuple):
-            # This will be deprecated in v0.34.0
+        if rope_interpolation_scale is not None:
             grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
             grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
             grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
-        else:
-            if not self._causal_rope_fix:
-                grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
-            else:
-                grid[:, 0:1] = (
-                    ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
-                    * self.patch_size_t
-                    / self.base_num_frames
-                )
-            grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height
-            grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width

         grid = grid.flatten(2, 4).transpose(1, 2)

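For context, the branch that survives scales each axis of the (frame, height, width) coordinate grid by its interpolation factor and patch size, normalized by the base dimensions. A minimal standalone sketch of that arithmetic (the shapes, base_* values, and scale tuple below are illustrative assumptions, not values taken from this commit):

import torch

# Standalone rendition of the retained scaling branch. The grid layout
# [batch, 3, frames, height, width] mirrors the surrounding module; the
# concrete sizes and scale factors are assumptions for illustration.
batch_size, f, h, w = 1, 4, 8, 8
patch_size_t, patch_size = 1, 1
base_num_frames, base_height, base_width = 20, 2048, 2048
rope_interpolation_scale = (1.0, 32.0, 32.0)  # (temporal, height, width)

grid = torch.stack(
    torch.meshgrid(torch.arange(f), torch.arange(h), torch.arange(w), indexing="ij"), dim=0
).float()
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * patch_size_t / base_num_frames
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * patch_size / base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * patch_size / base_width

grid = grid.flatten(2, 4).transpose(1, 2)  # -> [batch, f*h*w, 3] normalized coords

The removed else branch covered tensor-valued scales plus the `_causal_rope_fix` clamp; after this commit only the single path above remains, gated on `rope_interpolation_scale is not None`.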

@@ -363,7 +349,6 @@ def __init__(
         caption_channels: int = 4096,
         attention_bias: bool = True,
         attention_out_bias: bool = True,
-        _causal_rope_fix: bool = False,
     ) -> None:
         super().__init__()


@@ -385,7 +370,6 @@ def __init__(
             patch_size=patch_size,
             patch_size_t=patch_size_t,
             theta=10000.0,
-            _causal_rope_fix=_causal_rope_fix,
         )

         self.transformer_blocks = nn.ModuleList(
@@ -441,10 +425,6 @@ def forward(
                 "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
             )

-        if not isinstance(rope_interpolation_scale, torch.Tensor):
-            msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0."
-            deprecate("rope_interpolation_scale", "0.34.0", msg)
-
         image_rotary_emb = self.rope(
             hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale, video_coords
         )
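With the shim gone, `forward` passes `rope_interpolation_scale` to the rope module untouched. A hedged usage sketch of the positional-embedding call in isolation, assuming the rope class in this file is `LTXVideoRotaryPosEmbed` and that its output is what gets bound to `image_rotary_emb`; the `dim` value, shapes, and scale tuple are illustrative assumptions:

import torch
from diffusers.models.transformers.transformer_ltx import LTXVideoRotaryPosEmbed

# Illustrative configuration; `dim` and the base_* values are assumptions.
rope = LTXVideoRotaryPosEmbed(dim=64, base_num_frames=20, base_height=2048, base_width=2048)

batch, num_frames, height, width = 1, 4, 16, 16
hidden_states = torch.randn(batch, num_frames * height * width, 64)

# A plain tuple now flows through the single retained branch; the tuple
# path no longer emits a deprecation warning.
image_rotary_emb = rope(
    hidden_states, num_frames, height, width,
    frame_rate=25, rope_interpolation_scale=(1.0, 32.0, 32.0),
)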
