@@ -1005,6 +1005,11 @@ def __call__(
 
             if start_index > 0:
                 conditioning_mask = conditioning_mask.gather(1, video_ids[:, 0])
+                conditioning_mask_model_input = (
+                    torch.cat([conditioning_mask, conditioning_mask])
+                    if self.do_classifier_free_guidance
+                    else conditioning_mask
+                )
 
             video_ids = self._scale_video_ids(
                 video_ids,
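The hunk above hoists the classifier-free-guidance duplication of `conditioning_mask` out of the denoising loop: the mask depends only on the conditioning layout, not on the timestep, so `conditioning_mask_model_input` can be built once per tile. A minimal, self-contained sketch of the pattern — shapes and values are made up for illustration, only the names and the `torch.cat` duplication mirror the diff:

```python
import torch

# Illustrative only: 2 sequences of 6 latent tokens; 1.0 marks conditioning
# tokens that are already clean and should not be re-noised.
conditioning_mask = torch.zeros(2, 6)
conditioning_mask[:, :2] = 1.0

do_classifier_free_guidance = True  # stands in for self.do_classifier_free_guidance

# Loop-invariant: duplicate the mask once for the [uncond, cond] batch instead
# of rebuilding it on every denoising step.
conditioning_mask_model_input = (
    torch.cat([conditioning_mask, conditioning_mask])
    if do_classifier_free_guidance
    else conditioning_mask
)
assert conditioning_mask_model_input.shape == (4, 6)
```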
@@ -1024,137 +1029,77 @@ def __call__(
             num_warmup_steps = max(len(inner_timesteps) - inner_num_inference_steps * self.scheduler.order, 0)
             self._num_timesteps = len(inner_timesteps)
 
-            if start_index == 0:
-                with self.progress_bar(total=inner_num_inference_steps) as progress_bar:
-                    for i, t in enumerate(inner_timesteps):
-                        if self.interrupt:
-                            continue
-
-                        self._current_timestep = t
-                        latent_model_input = torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
-                        latent_model_input = latent_model_input.to(prompt_embeds.dtype)
-                        timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
-
-                        with self.transformer.cache_context("cond_uncond"):
-                            noise_pred = self.transformer(
-                                hidden_states=latent_model_input,
-                                encoder_hidden_states=prompt_embeds,
-                                timestep=timestep,
-                                encoder_attention_mask=prompt_attention_mask,
-                                video_coords=video_ids,
-                                attention_kwargs=attention_kwargs,
-                                return_dict=False,
-                            )[0]
-
-                        if self.do_classifier_free_guidance:
-                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-                            timestep, _ = timestep.chunk(2)
-
-                        if self.guidance_rescale > 0:
-                            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                            noise_pred = rescale_noise_cfg(
-                                noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
-                            )
-
-                        latent_chunk = self.scheduler.step(
-                            -noise_pred, t, latent_chunk, per_token_timesteps=timestep, return_dict=False
-                        )[0]
+            with self.progress_bar(total=inner_num_inference_steps) as progress_bar:
+                for i, t in enumerate(inner_timesteps):
+                    if self.interrupt:
+                        continue
 
-                        if callback_on_step_end is not None:
-                            callback_kwargs = {}
-                            for k in callback_on_step_end_tensor_inputs:
-                                callback_kwargs[k] = locals()[k]
-                            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                            latent_chunk = callback_outputs.pop("latents", latent_chunk)
-                            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                        # call the callback, if provided
-                        if i == len(inner_timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                            progress_bar.update()
-
-                        if XLA_AVAILABLE:
-                            xm.mark_step()
-
-                tile_out_latents = self._unpack_latents(
-                    latent_chunk,
-                    latent_tile_num_frames,
-                    latent_tile_height,
-                    latent_tile_width,
-                    self.transformer_spatial_patch_size,
-                    self.transformer_temporal_patch_size,
-                )
-                first_tile_out_latents = tile_out_latents.clone()
-            else:
-                with self.progress_bar(total=inner_num_inference_steps) as progress_bar:
-                    for i, t in enumerate(inner_timesteps):
-                        if self.interrupt:
-                            continue
-
-                        self._current_timestep = t
-                        latent_model_input = torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
-                        latent_model_input = latent_model_input.to(prompt_embeds.dtype)
-                        conditioning_mask_model_input = (
-                            torch.cat([conditioning_mask, conditioning_mask])
-                            if self.do_classifier_free_guidance
-                            else conditioning_mask
-                        )
-                        timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
+                    self._current_timestep = t
+                    latent_model_input = torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
+                    latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+                    timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
+                    if start_index > 0:
                         timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
 
-                        with self.transformer.cache_context("cond_uncond"):
-                            noise_pred = self.transformer(
-                                hidden_states=latent_model_input,
-                                encoder_hidden_states=prompt_embeds,
-                                timestep=timestep,
-                                encoder_attention_mask=prompt_attention_mask,
-                                video_coords=video_ids,
-                                attention_kwargs=attention_kwargs,
-                                return_dict=False,
-                            )[0]
-
-                        if self.do_classifier_free_guidance:
-                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-                            timestep, _ = timestep.chunk(2)
-
-                        if self.guidance_rescale > 0:
-                            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                            noise_pred = rescale_noise_cfg(
-                                noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
-                            )
-
-                        denoised_latent_chunk = self.scheduler.step(
-                            -noise_pred, t, latent_chunk, per_token_timesteps=timestep, return_dict=False
+                    with self.transformer.cache_context("cond_uncond"):
+                        noise_pred = self.transformer(
+                            hidden_states=latent_model_input,
+                            encoder_hidden_states=prompt_embeds,
+                            timestep=timestep,
+                            encoder_attention_mask=prompt_attention_mask,
+                            video_coords=video_ids,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
                         )[0]
+
+                    if self.do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                        timestep, _ = timestep.chunk(2)
+
+                    if self.guidance_rescale > 0:
+                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                        noise_pred = rescale_noise_cfg(
+                            noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+                        )
+
+                    denoised_latent_chunk = self.scheduler.step(
+                        -noise_pred, t, latent_chunk, per_token_timesteps=timestep, return_dict=False
+                    )[0]
+                    if start_index == 0:
+                        latent_chunk = denoised_latent_chunk
+                    else:
                         tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
                         latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)
 
-                        if callback_on_step_end is not None:
-                            callback_kwargs = {}
-                            for k in callback_on_step_end_tensor_inputs:
-                                callback_kwargs[k] = locals()[k]
-                            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                            latent_chunk = callback_outputs.pop("latents", latent_chunk)
-                            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                        # call the callback, if provided
-                        if i == len(inner_timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                            progress_bar.update()
-
-                        if XLA_AVAILABLE:
-                            xm.mark_step()
-
-                latent_chunk = self._unpack_latents(
-                    latent_chunk,
-                    total_latent_num_frames,
-                    latent_tile_height,
-                    latent_tile_width,
-                    self.transformer_spatial_patch_size,
-                    self.transformer_temporal_patch_size,
-                )
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                        latent_chunk = callback_outputs.pop("latents", latent_chunk)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+                    # update the progress bar on scheduler-order boundaries
+                    if i == len(inner_timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+
+                    if XLA_AVAILABLE:
+                        xm.mark_step()
+
+            latent_chunk = self._unpack_latents(
+                latent_chunk,
+                total_latent_num_frames,
+                latent_tile_height,
+                latent_tile_width,
+                self.transformer_spatial_patch_size,
+                self.transformer_temporal_patch_size,
+            )
+
+            if start_index == 0:
+                first_tile_out_latents = latent_chunk.clone()
+            else:
                 # We drop the first latent frame as it's a reinterpreted 8-frame latent understood as 1-frame latent
                 latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames + 1 :, :, :]
                 latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(latent_chunk, first_tile_out_latents, adain_factor)
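The two context lines kept on the `start_index > 0` path are the heart of the per-token conditioning: each token's timestep is clamped by its conditioning strength, and `torch.where` only overwrites tokens that are actually being denoised at the current step. A runnable sketch with made-up shapes — only the mask semantics and the two expressions mirror the diff, everything else is illustrative:

```python
import torch

# Hypothetical shapes: 1 sequence, 4 latent tokens, 8 channels.
t = torch.tensor(800.0)                                   # current scheduler timestep
conditioning_mask = torch.tensor([[1.0, 0.5, 0.0, 0.0]])  # 1.0 = fully conditioned token
latent_chunk = torch.randn(1, 4, 8)
denoised_latent_chunk = torch.randn(1, 4, 8)              # stand-in for the scheduler.step output

# Clamp per-token timesteps: a token with conditioning strength m may be noised
# to at most (1 - m) * 1000, so fully conditioned tokens see a zero timestep.
timestep = t.expand(1).unsqueeze(-1).float()              # (batch, 1)
timestep = torch.min(timestep, (1 - conditioning_mask) * 1000.0)

# Only tokens whose noise ceiling (1 - m) lies above the current normalized
# timestep get overwritten; conditioning tokens keep their previous clean values.
tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)
print(tokens_to_denoise_mask.squeeze(-1))  # tensor([[False, False,  True,  True]])
```

With `start_index == 0` there is no conditioning mask yet, which is why the merged loop can simply take `denoised_latent_chunk` unmasked on that branch.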
@@ -1171,7 +1116,7 @@ def __call__(
                 ]
                 latent_chunk = torch.cat(parts, dim=2)
 
-                tile_out_latents = latent_chunk.clone()
+            tile_out_latents = latent_chunk.clone()
 
             tile_weights = self._create_spatial_weights(tile_out_latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap)
             final_latents[:, :, :, v_start:v_end, h_start:h_end] += latent_chunk * tile_weights
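The `guidance_rescale` branch cites §3.4 of https://arxiv.org/pdf/2305.08891.pdf: strong classifier-free guidance inflates the variance of the prediction, so the CFG output is rescaled toward the standard deviation of the text-conditioned prediction and blended back by `guidance_rescale`. A sketch of what `rescale_noise_cfg` computes along those lines — treat it as illustrative, not the canonical diffusers implementation:

```python
import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Per-sample std over all non-batch dimensions.
    dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=dims, keepdim=True)
    # Rescale the CFG output to the std of the text-conditioned prediction,
    # countering the over-exposure that strong guidance introduces.
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # guidance_rescale = 0 keeps plain CFG; 1 uses the fully rescaled result.
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
```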