@@ -138,43 +138,53 @@ def __init__(
138138 in_channels = in_channels , out_channels = out_channels , kernel_size = 1 , stride = 1 , is_causal = is_causal
139139 )
140140
141- self .scale1 = None
142- self .scale2 = None
141+ self .per_channel_scale1 = None
142+ self .per_channel_scale2 = None
143143 if inject_noise :
144- self .scale1 = nn .Parameter (torch .zeros (in_channels , 1 , 1 ))
145- self .scale2 = nn .Parameter (torch .zeros (in_channels , 1 , 1 ))
144+ self .per_channel_scale1 = nn .Parameter (torch .zeros (in_channels , 1 , 1 ))
145+ self .per_channel_scale2 = nn .Parameter (torch .zeros (in_channels , 1 , 1 ))
146146
147147 self .scale_shift_table = None
148148 if timestep_conditioning :
149149 self .scale_shift_table = nn .Parameter (torch .randn (4 , in_channels ) / in_channels ** 0.5 )
150150
151- def forward (self , inputs : torch .Tensor , temb : Optional [torch .Tensor ] = None ) -> torch .Tensor :
151+ def forward (
152+ self , inputs : torch .Tensor , temb : Optional [torch .Tensor ] = None , generator : Optional [torch .Generator ] = None
153+ ) -> torch .Tensor :
152154 hidden_states = inputs
153155
154156 hidden_states = self .norm1 (hidden_states .movedim (1 , - 1 )).movedim (- 1 , 1 )
155- scale_1 , shift_1 , scale_2 , shift_2 = self .scale_shift_table .unbind (dim = 0 )
157+
158+ if self .scale_shift_table is not None :
159+ temb = temb .unflatten (1 , (4 , - 1 )) + self .scale_shift_table [None , ..., None , None , None ]
160+ shift_1 , scale_1 , shift_2 , scale_2 = temb .unbind (dim = 1 )
161+ hidden_states = hidden_states * (1 + scale_1 ) + shift_1
156162
157163 hidden_states = self .nonlinearity (hidden_states )
158164 hidden_states = self .conv1 (hidden_states )
159165
160- if self .scale1 is not None :
166+ if self .per_channel_scale1 is not None :
161167 spatial_shape = hidden_states .shape [- 2 :]
162- spatial_noise = torch .randn (spatial_shape , device = hidden_states .device , dtype = hidden_states .dtype )
163- hidden_states = hidden_states + (spatial_noise * self .scale1 )[None , :, None , :, :]
168+ spatial_noise = torch .randn (
169+ spatial_shape , generator = generator , device = hidden_states .device , dtype = hidden_states .dtype
170+ )
171+ hidden_states = hidden_states + (spatial_noise * self .per_channel_scale1 )[None , :, None , :, :]
164172
165173 hidden_states = self .norm2 (hidden_states .movedim (1 , - 1 )).movedim (- 1 , 1 )
166174
167175 if self .scale_shift_table is not None :
168- hidden_states = hidden_states * (1 + scale_1 ) + shift_1
176+ hidden_states = hidden_states * (1 + scale_2 ) + shift_2
169177
170178 hidden_states = self .nonlinearity (hidden_states )
171179 hidden_states = self .dropout (hidden_states )
172180 hidden_states = self .conv2 (hidden_states )
173181
174- if self .scale2 is not None :
182+ if self .per_channel_scale2 is not None :
175183 spatial_shape = hidden_states .shape [- 2 :]
176- spatial_noise = torch .randn (spatial_shape , device = hidden_states .device , dtype = hidden_states .dtype )
177- hidden_states = hidden_states + (spatial_noise * self .scale2 )[None , :, None , :, :]
184+ spatial_noise = torch .randn (
185+ spatial_shape , generator = generator , device = hidden_states .device , dtype = hidden_states .dtype
186+ )
187+ hidden_states = hidden_states + (spatial_noise * self .per_channel_scale2 )[None , :, None , :, :]
178188
179189 if self .norm3 is not None :
180190 inputs = self .norm3 (inputs .movedim (1 , - 1 )).movedim (- 1 , 1 )
@@ -318,7 +328,12 @@ def __init__(
318328
319329 self .gradient_checkpointing = False
320330
321- def forward (self , hidden_states : torch .Tensor ) -> torch .Tensor :
331+ def forward (
332+ self ,
333+ hidden_states : torch .Tensor ,
334+ temb : Optional [torch .Tensor ] = None ,
335+ generator : Optional [torch .Generator ] = None ,
336+ ) -> torch .Tensor :
322337 r"""Forward method of the `LTXDownBlock3D` class."""
323338
324339 for i , resnet in enumerate (self .resnets ):
@@ -330,16 +345,18 @@ def create_forward(*inputs):
330345
331346 return create_forward
332347
333- hidden_states = torch .utils .checkpoint .checkpoint (create_custom_forward (resnet ), hidden_states )
348+ hidden_states = torch .utils .checkpoint .checkpoint (
349+ create_custom_forward (resnet ), hidden_states , temb , generator
350+ )
334351 else :
335- hidden_states = resnet (hidden_states )
352+ hidden_states = resnet (hidden_states , temb , generator )
336353
337354 if self .downsamplers is not None :
338355 for downsampler in self .downsamplers :
339356 hidden_states = downsampler (hidden_states )
340357
341358 if self .conv_out is not None :
342- hidden_states = self .conv_out (hidden_states )
359+ hidden_states = self .conv_out (hidden_states , temb , generator )
343360
344361 return hidden_states
345362
@@ -401,7 +418,12 @@ def __init__(
401418
402419 self .gradient_checkpointing = False
403420
404- def forward (self , hidden_states : torch .Tensor , temb : Optional [torch .Tensor ] = None ) -> torch .Tensor :
421+ def forward (
422+ self ,
423+ hidden_states : torch .Tensor ,
424+ temb : Optional [torch .Tensor ] = None ,
425+ generator : Optional [torch .Generator ] = None ,
426+ ) -> torch .Tensor :
405427 r"""Forward method of the `LTXMidBlock3D` class."""
406428
407429 if self .time_embedder is not None :
@@ -423,9 +445,11 @@ def create_forward(*inputs):
423445
424446 return create_forward
425447
426- hidden_states = torch .utils .checkpoint .checkpoint (create_custom_forward (resnet ), hidden_states , temb )
448+ hidden_states = torch .utils .checkpoint .checkpoint (
449+ create_custom_forward (resnet ), hidden_states , temb , generator
450+ )
427451 else :
428- hidden_states = resnet (hidden_states , temb )
452+ hidden_states = resnet (hidden_states , temb , generator )
429453
430454 return hidden_states
431455
@@ -524,9 +548,14 @@ def __init__(
524548
525549 self .gradient_checkpointing = False
526550
527- def forward (self , hidden_states : torch .Tensor , temb : Optional [torch .Tensor ] = None ) -> torch .Tensor :
551+ def forward (
552+ self ,
553+ hidden_states : torch .Tensor ,
554+ temb : Optional [torch .Tensor ] = None ,
555+ generator : Optional [torch .Generator ] = None ,
556+ ) -> torch .Tensor :
528557 if self .conv_in is not None :
529- hidden_states = self .conv_in (hidden_states )
558+ hidden_states = self .conv_in (hidden_states , temb , generator )
530559
531560 if self .time_embedder is not None :
532561 temb = self .time_embedder (
@@ -551,9 +580,11 @@ def create_forward(*inputs):
551580
552581 return create_forward
553582
554- hidden_states = torch .utils .checkpoint .checkpoint (create_custom_forward (resnet ), hidden_states )
583+ hidden_states = torch .utils .checkpoint .checkpoint (
584+ create_custom_forward (resnet ), hidden_states , temb , generator
585+ )
555586 else :
556- hidden_states = resnet (hidden_states )
587+ hidden_states = resnet (hidden_states , temb , generator )
557588
558589 return hidden_states
559590
@@ -746,6 +777,9 @@ def __init__(
746777 block_out_channels = tuple (reversed (block_out_channels ))
747778 spatio_temporal_scaling = tuple (reversed (spatio_temporal_scaling ))
748779 layers_per_block = tuple (reversed (layers_per_block ))
780+ inject_noise = tuple (reversed (inject_noise ))
781+ upsample_residual = tuple (reversed (upsample_residual ))
782+ upsample_factor = tuple (reversed (upsample_factor ))
749783 output_channel = block_out_channels [0 ]
750784
751785 self .conv_in = LTXCausalConv3d (
@@ -810,29 +844,31 @@ def create_forward(*inputs):
810844
811845 return create_forward
812846
813- hidden_states = torch .utils .checkpoint .checkpoint (create_custom_forward (self .mid_block ), hidden_states )
847+ hidden_states = torch .utils .checkpoint .checkpoint (
848+ create_custom_forward (self .mid_block ), hidden_states , temb
849+ )
814850
815851 for up_block in self .up_blocks :
816- hidden_states = torch .utils .checkpoint .checkpoint (create_custom_forward (up_block ), hidden_states )
852+ hidden_states = torch .utils .checkpoint .checkpoint (create_custom_forward (up_block ), hidden_states , temb )
817853 else :
818- hidden_states = self .mid_block (hidden_states )
854+ hidden_states = self .mid_block (hidden_states , temb )
819855
820856 for up_block in self .up_blocks :
821- hidden_states = up_block (hidden_states )
857+ hidden_states = up_block (hidden_states , temb )
822858
823859 hidden_states = self .norm_out (hidden_states .movedim (1 , - 1 )).movedim (- 1 , 1 )
824860
825861 if self .time_embedder is not None :
826- embedded_timestep = self .time_embedder (
862+ temb = self .time_embedder (
827863 timestep = temb .flatten (),
828864 resolution = None ,
829865 aspect_ratio = None ,
830866 batch_size = hidden_states .size (0 ),
831867 hidden_dtype = hidden_states .dtype ,
832868 )
833- embedded_timestep = embedded_timestep .view (hidden_states .size (0 ), - 1 , 1 , 1 , 1 ).unflatten (1 , (2 , - 1 ))
834- embedded_timestep = embedded_timestep + self .scale_shift_table [None , : , None , None , None ]
835- shift , scale = embedded_timestep .unbind (dim = 1 )
869+ temb = temb .view (hidden_states .size (0 ), - 1 , 1 , 1 , 1 ).unflatten (1 , (2 , - 1 ))
870+ temb = temb + self .scale_shift_table [None , ... , None , None , None ]
871+ shift , scale = temb .unbind (dim = 1 )
836872 hidden_states = hidden_states * (1 + scale ) + shift
837873
838874 hidden_states = self .conv_act (hidden_states )
@@ -902,7 +938,7 @@ def __init__(
902938 decoder_layers_per_block : Tuple [int , ...] = (4 , 3 , 3 , 3 , 4 ),
903939 spatio_temporal_scaling : Tuple [bool , ...] = (True , True , True , False ),
904940 decoder_spatio_temporal_scaling : Tuple [bool , ...] = (True , True , True , False ),
905- decoder_inject_noise : Tuple [bool , ...] = (False , False , False , False ),
941+ decoder_inject_noise : Tuple [bool , ...] = (False , False , False , False , False ),
906942 upsample_residual : Tuple [bool , ...] = (False , False , False , False ),
907943 upsample_factor : Tuple [int , ...] = (1 , 1 , 1 , 1 ),
908944 timestep_conditioning : bool = False ,
@@ -1078,13 +1114,15 @@ def encode(
10781114 return (posterior ,)
10791115 return AutoencoderKLOutput (latent_dist = posterior )
10801116
1081- def _decode (self , z : torch .Tensor , return_dict : bool = True ) -> Union [DecoderOutput , torch .Tensor ]:
1117+ def _decode (
1118+ self , z : torch .Tensor , temb : Optional [torch .Tensor ] = None , return_dict : bool = True
1119+ ) -> Union [DecoderOutput , torch .Tensor ]:
10821120 batch_size , num_channels , num_frames , height , width = z .shape
10831121 tile_latent_min_height = self .tile_sample_min_height // self .spatial_compression_ratio
10841122 tile_latent_min_width = self .tile_sample_stride_width // self .spatial_compression_ratio
10851123
10861124 if self .use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height ):
1087- return self .tiled_decode (z , return_dict = return_dict )
1125+ return self .tiled_decode (z , temb , return_dict = return_dict )
10881126
10891127 if self .use_framewise_decoding :
10901128 # TODO(aryan): requires investigation
@@ -1094,15 +1132,17 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
10941132 "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
10951133 )
10961134 else :
1097- dec = self .decoder (z )
1135+ dec = self .decoder (z , temb )
10981136
10991137 if not return_dict :
11001138 return (dec ,)
11011139
11021140 return DecoderOutput (sample = dec )
11031141
11041142 @apply_forward_hook
1105- def decode (self , z : torch .Tensor , return_dict : bool = True ) -> Union [DecoderOutput , torch .Tensor ]:
1143+ def decode (
1144+ self , z : torch .Tensor , temb : Optional [torch .Tensor ] = None , return_dict : bool = True
1145+ ) -> Union [DecoderOutput , torch .Tensor ]:
11061146 """
11071147 Decode a batch of images.
11081148
@@ -1117,10 +1157,15 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
11171157 returned.
11181158 """
11191159 if self .use_slicing and z .shape [0 ] > 1 :
1120- decoded_slices = [self ._decode (z_slice ).sample for z_slice in z .split (1 )]
1160+ if temb is not None :
1161+ decoded_slices = [
1162+ self ._decode (z_slice , t_slice ).sample for z_slice , t_slice in (z .split (1 ), temb .split (1 ))
1163+ ]
1164+ else :
1165+ decoded_slices = [self ._decode (z_slice ).sample for z_slice in z .split (1 )]
11211166 decoded = torch .cat (decoded_slices )
11221167 else :
1123- decoded = self ._decode (z ).sample
1168+ decoded = self ._decode (z , temb ).sample
11241169
11251170 if not return_dict :
11261171 return (decoded ,)
@@ -1202,7 +1247,9 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
12021247 enc = torch .cat (result_rows , dim = 3 )[:, :, :, :latent_height , :latent_width ]
12031248 return enc
12041249
1205- def tiled_decode (self , z : torch .Tensor , return_dict : bool = True ) -> Union [DecoderOutput , torch .Tensor ]:
1250+ def tiled_decode (
1251+ self , z : torch .Tensor , temb : Optional [torch .Tensor ], return_dict : bool = True
1252+ ) -> Union [DecoderOutput , torch .Tensor ]:
12061253 r"""
12071254 Decode a batch of images using a tiled decoder.
12081255
@@ -1243,7 +1290,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
12431290 "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
12441291 )
12451292 else :
1246- time = self .decoder (z [:, :, :, i : i + tile_latent_min_height , j : j + tile_latent_min_width ])
1293+ time = self .decoder (
1294+ z [:, :, :, i : i + tile_latent_min_height , j : j + tile_latent_min_width ], temb
1295+ )
12471296
12481297 row .append (time )
12491298 rows .append (row )
@@ -1271,6 +1320,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
12711320 def forward (
12721321 self ,
12731322 sample : torch .Tensor ,
1323+ temb : Optional [torch .Tensor ] = None ,
12741324 sample_posterior : bool = False ,
12751325 return_dict : bool = True ,
12761326 generator : Optional [torch .Generator ] = None ,
@@ -1281,7 +1331,7 @@ def forward(
12811331 z = posterior .sample (generator = generator )
12821332 else :
12831333 z = posterior .mode ()
1284- dec = self .decode (z )
1334+ dec = self .decode (z , temb )
12851335 if not return_dict :
12861336 return (dec ,)
12871337 return dec
0 commit comments