@@ -187,20 +187,20 @@ def forward(
         hidden_states = self.norm(hidden_states)
         hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
 
-        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.proj_in(input=hidden_states)
 
         # 2. Blocks
         for block in self.transformer_blocks:
             hidden_states = block(
-                hidden_states,
+                hidden_states=hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 timestep=timestep,
                 cross_attention_kwargs=cross_attention_kwargs,
                 class_labels=class_labels,
             )
 
         # 3. Output
-        hidden_states = self.proj_out(hidden_states)
+        hidden_states = self.proj_out(input=hidden_states)
         hidden_states = (
             hidden_states[None, None, :]
             .reshape(batch_size, height, width, num_frames, channel)
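For context on the reshapes this hunk sits between: the temporal transformer treats every spatial position as its own sequence of frames. A minimal round-trip sketch in plain PyTorch (stand-in shapes, not the library module), assuming the incoming tensor is laid out (batch, channel, frames, height, width) as the permute indices suggest:

import torch

# Sketch only: illustrative shapes, not the diffusers code.
batch_size, channel, num_frames, height, width = 2, 8, 16, 4, 4
sample = torch.randn(batch_size, channel, num_frames, height, width)

# (B, C, F, H, W) -> (B, H, W, F, C) -> (B*H*W, F, C): one frame sequence per spatial position
hidden_states = sample.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

# ... the transformer blocks attend over the num_frames axis here ...

# Inverse of the "3. Output" reshape: back to (B, C, F, H, W)
restored = (
    hidden_states.reshape(batch_size, height, width, num_frames, channel)
    .permute(0, 4, 3, 1, 2)
    .contiguous()
)
assert torch.equal(restored, sample)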
@@ -344,15 +344,15 @@ def custom_forward(*inputs):
                 )
 
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)
 
             hidden_states = motion_module(hidden_states, num_frames=num_frames)
 
             output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                hidden_states = downsampler(hidden_states=hidden_states)
 
             output_states = output_states + (hidden_states,)
 
@@ -531,25 +531,18 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)
+
+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]
 
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             hidden_states = motion_module(
                 hidden_states,
                 num_frames=num_frames,
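The create_custom_forward and ckpt_kwargs lines this hunk keeps are the usual activation-checkpointing pattern. A self-contained sketch of that pattern, with a Linear stand-in for the resnet block and use_reentrant=False hard-coded (the real code only sets it on torch >= 1.11):

from typing import Any, Dict

import torch
import torch.utils.checkpoint


def create_custom_forward(module):
    # Wrap the sub-module so checkpoint() can re-run its forward during backward
    # instead of keeping its intermediate activations in memory.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


resnet = torch.nn.Linear(8, 8)  # stand-in for the resnet block
hidden_states = torch.randn(4, 8, requires_grad=True)

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet),
    hidden_states,
    **ckpt_kwargs,
)
hidden_states.sum().backward()  # the wrapped forward is recomputed here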
@@ -563,7 +556,7 @@ def custom_forward(*inputs):
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                hidden_states = downsampler(hidden_states=hidden_states)
 
             output_states = output_states + (hidden_states,)
 
@@ -757,33 +750,26 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)
+
+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]
 
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             hidden_states = motion_module(
                 hidden_states,
                 num_frames=num_frames,
             )
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
+                hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)
 
         return hidden_states
 
@@ -929,13 +915,13 @@ def custom_forward(*inputs):
                     create_custom_forward(resnet), hidden_states, temb
                 )
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)
 
             hidden_states = motion_module(hidden_states, num_frames=num_frames)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
+                hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)
 
         return hidden_states
 
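On the output_size=upsample_size keyword in the two upsampler hunks: passing an explicit target size matters when the skip connection has an odd spatial size, because a plain 2x upsample would not match it. A rough sketch of that situation with torch.nn.functional.interpolate, as an illustration of the idea rather than the Upsample2D implementation:

import torch
import torch.nn.functional as F

# Encoder feature with an odd spatial size; after downsampling it becomes 4x4.
skip = torch.randn(1, 8, 7, 7)
down = torch.randn(1, 8, 4, 4)

doubled = F.interpolate(down, scale_factor=2.0, mode="nearest")      # 8x8: does not match the skip
matched = F.interpolate(down, size=skip.shape[-2:], mode="nearest")  # 7x7: matches, like passing an explicit output size

assert matched.shape[-2:] == skip.shape[-2:]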
@@ -1080,10 +1066,19 @@ def forward(
         if cross_attention_kwargs.get("scale", None) is not None:
             logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
-        hidden_states = self.resnets[0](hidden_states, temb)
+        hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb)
 
         blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
         for attn, resnet, motion_module in blocks:
+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
@@ -1096,14 +1091,6 @@ def custom_forward(*inputs):
                     return custom_forward
 
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(motion_module),
                     hidden_states,
@@ -1117,19 +1104,11 @@ def custom_forward(*inputs):
                     **ckpt_kwargs,
                 )
             else:
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
                 hidden_states = motion_module(
                     hidden_states,
                     num_frames=num_frames,
                 )
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)
 
         return hidden_states
 
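Taken together, the mid-block hunks leave a per-iteration order of cross-attention, then the motion module, then the paired resnet, after the initial resnets[0] call. A toy sketch of that control flow with pass-through stand-in modules (not the diffusers classes):

import torch
from torch import nn


class PassThrough(nn.Module):
    # Stand-in that ignores extra arguments so the loop structure is the focus.
    def forward(self, hidden_states, *args, **kwargs):
        return hidden_states


resnets = nn.ModuleList([PassThrough() for _ in range(3)])        # resnets[0] plus one per attention
attentions = nn.ModuleList([PassThrough() for _ in range(2)])
motion_modules = nn.ModuleList([PassThrough() for _ in range(2)])

hidden_states = torch.randn(2, 8)
temb = torch.randn(2, 8)

hidden_states = resnets[0](hidden_states, temb)
for attn, resnet, motion_module in zip(attentions, resnets[1:], motion_modules):
    hidden_states = attn(hidden_states)            # cross-attention first
    hidden_states = motion_module(hidden_states)   # then the temporal/motion layers
    hidden_states = resnet(hidden_states, temb)    # then the paired resnet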