Commit d76232d

address review comments
1 parent 5196b2a commit d76232d

6 files changed: +30 −69 lines


scripts/convert_ltx_to_diffusers.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
     # common
     "conv_shortcut": "conv_shortcut.conv",
     "res_blocks": "resnets",
+    "norm3.norm": "norm3",
     "per_channel_statistics.mean-of-means": "latents_mean",
     "per_channel_statistics.std-of-means": "latents_std",
 }
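
The new `"norm3.norm": "norm3"` entry remaps the original checkpoint's wrapped norm weights (`norm3.norm.*`) onto the plain `norm3` module, matching the autoencoder change below. As a rough illustration of how a substring-rename table like this is typically applied during conversion (the helper and dict names here are hypothetical, not the script's actual code):

    from typing import Any, Dict

    # Hypothetical rename table in the style of the one edited above.
    KEYS_RENAME_DICT = {
        "res_blocks": "resnets",
        "norm3.norm": "norm3",
    }

    def rename_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
        # Replace every matching substring in each key; tensors are left untouched.
        converted = {}
        for key, value in state_dict.items():
            new_key = key
            for old, new in KEYS_RENAME_DICT.items():
                new_key = new_key.replace(old, new)
            converted[new_key] = value
        return converted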

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 11 additions & 11 deletions
@@ -23,7 +23,7 @@
 from ..activations import get_activation
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import LayerNormNd, RMSNormNd
+from ..normalization import RMSNorm
 from .vae import DecoderOutput, DiagonalGaussianDistribution


@@ -117,12 +117,12 @@ def __init__(

         self.nonlinearity = get_activation(non_linearity)

-        self.norm1 = RMSNormNd(dim=in_channels, eps=1e-8, elementwise_affine=elementwise_affine, channel_dim=1)
+        self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
         self.conv1 = LTXCausalConv3d(
             in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
         )

-        self.norm2 = RMSNormNd(dim=out_channels, eps=1e-8, elementwise_affine=elementwise_affine, channel_dim=1)
+        self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
         self.dropout = nn.Dropout(dropout)
         self.conv2 = LTXCausalConv3d(
             in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal

@@ -131,25 +131,25 @@ def __init__(
         self.norm3 = None
         self.conv_shortcut = None
         if in_channels != out_channels:
-            self.norm3 = LayerNormNd(in_channels, eps=eps, elementwise_affine=True, bias=True, channel_dim=1)
+            self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
             self.conv_shortcut = LTXCausalConv3d(
                 in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
             )

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         hidden_states = inputs

-        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.conv1(hidden_states)

-        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.conv2(hidden_states)

         if self.norm3 is not None:
-            inputs = self.norm3(inputs)
+            inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)

         if self.conv_shortcut is not None:
             inputs = self.conv_shortcut(inputs)

@@ -545,7 +545,7 @@ def __init__(
         )

         # out
-        self.norm_out = RMSNormNd(dim=out_channels, eps=1e-8, elementwise_affine=False, channel_dim=1)
+        self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
         self.conv_act = nn.SiLU()
         self.conv_out = LTXCausalConv3d(
             in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal

@@ -589,7 +589,7 @@ def create_forward(*inputs):

         hidden_states = self.mid_block(hidden_states)

-        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
         hidden_states = self.conv_act(hidden_states)
         hidden_states = self.conv_out(hidden_states)

@@ -675,7 +675,7 @@ def __init__(
             self.up_blocks.append(up_block)

         # out
-        self.norm_out = RMSNormNd(dim=out_channels, eps=1e-8, elementwise_affine=False, channel_dim=1)
+        self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
         self.conv_act = nn.SiLU()
         self.conv_out = LTXCausalConv3d(
             in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal

@@ -704,7 +704,7 @@ def create_forward(*inputs):
         for up_block in self.up_blocks:
             hidden_states = up_block(hidden_states)

-        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
         hidden_states = self.conv_act(hidden_states)
         hidden_states = self.conv_out(hidden_states)
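
The wrappers are replaced by plain `RMSNorm` / `nn.LayerNorm` plus explicit `movedim` calls: the VAE activations are channels-first 5D tensors (batch, channels, frames, height, width), while both norm layers normalize over the trailing dimension, so channels are moved last just for the normalization and moved back afterwards. A minimal sketch of that pattern, shown here with `nn.LayerNorm` (the model applies diffusers' `RMSNorm` the same way):

    import torch
    import torch.nn as nn

    # Channels-first video activations: (batch, channels, frames, height, width).
    x = torch.randn(1, 128, 4, 8, 8)

    norm = nn.LayerNorm(128, eps=1e-6)

    # LayerNorm/RMSNorm normalize the last dimension, so move channels there and back.
    out = norm(x.movedim(1, -1)).movedim(-1, 1)

    assert out.shape == x.shape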

src/diffusers/models/normalization.py

Lines changed: 1 addition & 52 deletions
@@ -14,7 +14,7 @@
 # limitations under the License.

 import numbers
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple

 import torch
 import torch.nn as nn

@@ -567,54 +567,3 @@ def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12):

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
-
-
-class LayerNormNd(nn.Module):
-    def __init__(
-        self,
-        normalized_shape: Union[int, List[int], Tuple[int], torch.Size],
-        eps: float = 1e-5,
-        elementwise_affine: bool = True,
-        bias: bool = True,
-        device=None,
-        dtype=None,
-        channel_dim: int = -1,
-    ) -> None:
-        super().__init__()
-
-        self.norm = nn.LayerNorm(normalized_shape, eps, elementwise_affine, bias, device, dtype)
-        self.channel_dim = channel_dim
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if self.channel_dim != -1:
-            hidden_states = hidden_states.movedim(self.channel_dim, -1)
-            hidden_states = self.norm(hidden_states)
-            hidden_states = hidden_states.movedim(-1, self.channel_dim)
-        else:
-            hidden_states = self.norm(hidden_states)
-
-        return hidden_states
-
-
-class RMSNormNd(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        eps: float,
-        elementwise_affine: bool = True,
-        channel_dim: int = -1,
-    ) -> None:
-        super().__init__()
-
-        self.norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.channel_dim = channel_dim
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if self.channel_dim != -1:
-            hidden_states = hidden_states.movedim(self.channel_dim, -1)
-            hidden_states = self.norm(hidden_states)
-            hidden_states = hidden_states.movedim(-1, self.channel_dim)
-        else:
-            hidden_states = self.norm(hidden_states)
-
-        return hidden_states

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 2 additions & 4 deletions
@@ -62,10 +62,8 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

-        use_rotary_emb = False
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-            use_rotary_emb = True

         query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)

@@ -74,7 +72,7 @@ def __call__(
         query = attn.norm_q(query)
         key = attn.norm_k(key)

-        if image_rotary_emb is not None and use_rotary_emb:
+        if image_rotary_emb is not None:
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

@@ -255,7 +253,7 @@ def forward(
         attn_hidden_states = self.attn2(
             hidden_states,
             encoder_hidden_states=encoder_hidden_states,
-            image_rotary_emb=image_rotary_emb,
+            image_rotary_emb=None,
             attention_mask=encoder_attention_mask,
         )
         hidden_states = hidden_states + attn_hidden_states
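
With the `use_rotary_emb` flag removed from the processor, the caller now decides where RoPE applies: the block still passes `image_rotary_emb` to self-attention but passes `image_rotary_emb=None` to cross-attention (`attn2`), so the simpler `if image_rotary_emb is not None` check preserves the previous behavior. A rough sketch of that calling convention (function and argument names are illustrative, not the module's exact code):

    def run_attention_layers(hidden_states, encoder_hidden_states, image_rotary_emb, attn1, attn2):
        # Self-attention over video tokens: rotary embeddings carry their positions.
        hidden_states = hidden_states + attn1(hidden_states, image_rotary_emb=image_rotary_emb)
        # Cross-attention against text tokens: no positional rotation, so pass None.
        hidden_states = hidden_states + attn2(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=None,
        )
        return hidden_states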

src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py

Lines changed: 0 additions & 2 deletions
@@ -774,7 +774,6 @@ def __call__(
                     timestep, _ = timestep.chunk(2)

                 # compute the previous noisy sample x_t -> x_t-1
-                # ============= TODO(aryan): needs a look by YiYi
                 noise_pred = self._unpack_latents(
                     noise_pred,
                     latent_num_frames,

@@ -800,7 +799,6 @@ def __call__(
                 latents = self._pack_latents(
                     latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
                 )
-                # =============

                 if callback_on_step_end is not None:
                     callback_kwargs = {}

src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Lines changed: 15 additions & 0 deletions
@@ -183,6 +183,21 @@ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
         return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

     def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
+        r"""
+        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
+        value.
+
+        Reference:
+        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
+
+        Args:
+            t (`torch.Tensor`):
+                A tensor of timesteps to be stretched and shifted.
+
+        Returns:
+            `torch.Tensor`:
+                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
+        """
         one_minus_z = 1 - t
         scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
         stretched_t = 1 - (one_minus_z / scale_factor)