2323from ..activations import get_activation
2424from ..modeling_outputs import AutoencoderKLOutput
2525from ..modeling_utils import ModelMixin
26- from ..normalization import LayerNormNd , RMSNormNd
26+ from ..normalization import RMSNorm
2727from .vae import DecoderOutput , DiagonalGaussianDistribution
2828
2929
@@ -117,12 +117,12 @@ def __init__(
117117
118118 self .nonlinearity = get_activation (non_linearity )
119119
120- self .norm1 = RMSNormNd ( dim = in_channels , eps = 1e-8 , elementwise_affine = elementwise_affine , channel_dim = 1 )
120+ self .norm1 = RMSNorm ( in_channels , eps = 1e-8 , elementwise_affine = elementwise_affine )
121121 self .conv1 = LTXCausalConv3d (
122122 in_channels = in_channels , out_channels = out_channels , kernel_size = 3 , is_causal = is_causal
123123 )
124124
125- self .norm2 = RMSNormNd ( dim = out_channels , eps = 1e-8 , elementwise_affine = elementwise_affine , channel_dim = 1 )
125+ self .norm2 = RMSNorm ( out_channels , eps = 1e-8 , elementwise_affine = elementwise_affine )
126126 self .dropout = nn .Dropout (dropout )
127127 self .conv2 = LTXCausalConv3d (
128128 in_channels = out_channels , out_channels = out_channels , kernel_size = 3 , is_causal = is_causal
@@ -131,25 +131,25 @@ def __init__(
131131 self .norm3 = None
132132 self .conv_shortcut = None
133133 if in_channels != out_channels :
134- self .norm3 = LayerNormNd (in_channels , eps = eps , elementwise_affine = True , bias = True , channel_dim = 1 )
134+ self .norm3 = nn . LayerNorm (in_channels , eps = eps , elementwise_affine = True , bias = True )
135135 self .conv_shortcut = LTXCausalConv3d (
136136 in_channels = in_channels , out_channels = out_channels , kernel_size = 1 , stride = 1 , is_causal = is_causal
137137 )
138138
139139 def forward (self , inputs : torch .Tensor ) -> torch .Tensor :
140140 hidden_states = inputs
141141
142- hidden_states = self .norm1 (hidden_states )
142+ hidden_states = self .norm1 (hidden_states . movedim ( 1 , - 1 )). movedim ( - 1 , 1 )
143143 hidden_states = self .nonlinearity (hidden_states )
144144 hidden_states = self .conv1 (hidden_states )
145145
146- hidden_states = self .norm2 (hidden_states )
146+ hidden_states = self .norm2 (hidden_states . movedim ( 1 , - 1 )). movedim ( - 1 , 1 )
147147 hidden_states = self .nonlinearity (hidden_states )
148148 hidden_states = self .dropout (hidden_states )
149149 hidden_states = self .conv2 (hidden_states )
150150
151151 if self .norm3 is not None :
152- inputs = self .norm3 (inputs )
152+ inputs = self .norm3 (inputs . movedim ( 1 , - 1 )). movedim ( - 1 , 1 )
153153
154154 if self .conv_shortcut is not None :
155155 inputs = self .conv_shortcut (inputs )
@@ -545,7 +545,7 @@ def __init__(
545545 )
546546
547547 # out
548- self .norm_out = RMSNormNd ( dim = out_channels , eps = 1e-8 , elementwise_affine = False , channel_dim = 1 )
548+ self .norm_out = RMSNorm ( out_channels , eps = 1e-8 , elementwise_affine = False )
549549 self .conv_act = nn .SiLU ()
550550 self .conv_out = LTXCausalConv3d (
551551 in_channels = output_channel , out_channels = out_channels + 1 , kernel_size = 3 , stride = 1 , is_causal = is_causal
@@ -589,7 +589,7 @@ def create_forward(*inputs):
589589
590590 hidden_states = self .mid_block (hidden_states )
591591
592- hidden_states = self .norm_out (hidden_states )
592+ hidden_states = self .norm_out (hidden_states . movedim ( 1 , - 1 )). movedim ( - 1 , 1 )
593593 hidden_states = self .conv_act (hidden_states )
594594 hidden_states = self .conv_out (hidden_states )
595595
@@ -675,7 +675,7 @@ def __init__(
675675 self .up_blocks .append (up_block )
676676
677677 # out
678- self .norm_out = RMSNormNd ( dim = out_channels , eps = 1e-8 , elementwise_affine = False , channel_dim = 1 )
678+ self .norm_out = RMSNorm ( out_channels , eps = 1e-8 , elementwise_affine = False )
679679 self .conv_act = nn .SiLU ()
680680 self .conv_out = LTXCausalConv3d (
681681 in_channels = output_channel , out_channels = self .out_channels , kernel_size = 3 , stride = 1 , is_causal = is_causal
@@ -704,7 +704,7 @@ def create_forward(*inputs):
704704 for up_block in self .up_blocks :
705705 hidden_states = up_block (hidden_states )
706706
707- hidden_states = self .norm_out (hidden_states )
707+ hidden_states = self .norm_out (hidden_states . movedim ( 1 , - 1 )). movedim ( - 1 , 1 )
708708 hidden_states = self .conv_act (hidden_states )
709709 hidden_states = self .conv_out (hidden_states )
710710
0 commit comments