@@ -74,6 +74,48 @@ def retrieve_latents(
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
 
+# TODO: move this into a utility module (or into the Transfer2_5 model)?
+def transfer2_5_forward(
+    transformer,
+    controlnet,
+    in_latents,
+    controls_latents,
+    controls_conditioning_scale,
+    in_timestep,
+    encoder_hidden_states,
+    cond_mask,
+    padding_mask,
+):
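+    # Single denoising forward pass shared by the positive and negative (CFG)
+    # branches: prepare the transformer inputs, run the ControlNet branch when
+    # control latents are provided, and return the predicted velocity.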
+    control_blocks = None
+    prepared_inputs = transformer.prepare_inputs(
+        hidden_states=in_latents,
+        condition_mask=cond_mask,
+        timestep=in_timestep,
+        encoder_hidden_states=encoder_hidden_states,
+        padding_mask=padding_mask,
+    )
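+    # The ControlNet branch consumes the already-projected conditioning tensors
+    # (encoder states, temb, rotary embeddings) produced by prepare_inputs.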
+    if controls_latents is not None:
+        control_blocks = controlnet(
+            controls_latents=controls_latents,
+            latents=in_latents,
+            conditioning_scale=controls_conditioning_scale,
+            condition_mask=cond_mask,
+            padding_mask=padding_mask,
+            encoder_hidden_states=prepared_inputs["encoder_hidden_states"],
+            temb=prepared_inputs["temb"],
+            embedded_timestep=prepared_inputs["embedded_timestep"],
+            image_rotary_emb=prepared_inputs["image_rotary_emb"],
+            extra_pos_emb=prepared_inputs["extra_pos_emb"],
+            attention_mask=prepared_inputs["attention_mask"],
+        )
+
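+    # ControlNet block outputs (if any) enter the transformer as per-block
+    # residual hidden states.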
+    noise_pred = transformer._forward(
+        prepared_inputs=prepared_inputs,
+        block_controlnet_hidden_states=control_blocks,
+        return_dict=False,
+    )[0]
+    return noise_pred
+
 
 EXAMPLE_DOC_STRING = """
     Examples:
@@ -227,7 +269,6 @@ def __init__(
 
         self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
-        # breakpoint()
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
         latents_mean = (
@@ -470,8 +511,10 @@ def prepare_latents(
 
         num_cond_latent_frames = (num_frames_in - 1) // self.vae_scale_factor_temporal + 1
         cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
-        cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0
-        cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
+        # cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0
+        # TODO: modify cond_mask per chunk
+        # cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
+        cond_mask = zeros_padding  # TODO: this is what i4 uses
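+        # NOTE: with an all-zeros cond_mask, in_latents in the denoising loop
+        # reduces to the noisy latents, i.e. no frames are clamped to the
+        # conditioning video.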
 
         return (
             latents,
@@ -569,7 +612,8 @@ def __call__(
         width: Optional[int] = None,
         num_frames: int = 93,
         num_inference_steps: int = 36,
-        guidance_scale: float = 7.0,
+        # guidance_scale: float = 7.0,  # TODO: check default
+        guidance_scale: float = 3.0,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -676,16 +720,21 @@ def __call__(
 
         if width is None:
             frame = image or video[0] if image or video else None
+            if frame is None and controls is not None:
+                frame = controls[0] if isinstance(controls, list) else controls
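+                # a 4-D controls tensor/array is a stack of frames; take the first frame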
+                if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4:
+                    frame = controls[0]
+
             if frame is None:
-                width = (height + 16) * (1280 / 720)
+                width = int((height + 16) * (1280 / 720))
             elif isinstance(frame, PIL.Image.Image):
                 width = int((height + 16) * (frame.width / frame.height))
             else:
                 width = int((height + 16) * (frame.shape[2] / frame.shape[1]))  # NOTE: assuming C H W
 
         # Check inputs. Raise error if not correct
         print("width=", width, "height=", height)
-        breakpoint()
+        # breakpoint()
         self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
 
         self._guidance_scale = guidance_scale
@@ -729,6 +778,7 @@ def __call__(
         )
         # TODO(migmartin): add img ref to prompt_embeds via siglip if provided
         encoder_hidden_states = (prompt_embeds, None)
+        neg_encoder_hidden_states = (negative_prompt_embeds, None)
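+        # second tuple element is a placeholder for the image-reference
+        # embeddings mentioned in the TODO above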
 
         vae_dtype = self.vae.dtype
         transformer_dtype = self.transformer.dtype
@@ -815,51 +865,37 @@ def __call__(
 
                 in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
                 in_latents = in_latents.to(transformer_dtype)
                 in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
-                control_blocks = None
-
-                prepared_inputs = self.transformer.prepare_inputs(
-                    hidden_states=in_latents,
-                    condition_mask=cond_mask,
-                    timestep=in_timestep,
+                noise_pred = transfer2_5_forward(
+                    transformer=self.transformer,
+                    controlnet=self.controlnet,
+                    in_latents=in_latents,
+                    controls_latents=controls_latents,
+                    controls_conditioning_scale=controls_conditioning_scale,
+                    in_timestep=in_timestep,
                     encoder_hidden_states=encoder_hidden_states,
-                    padding_mask=padding_mask,
+                    cond_mask=cond_mask,
+                    padding_mask=padding_mask,
                 )
-                # import IPython; IPython.embed()
-                # breakpoint()
-                if controls is not None:
-                    control_blocks = self.controlnet(
-                        controls_latents=controls_latents,
-                        latents=in_latents,
-                        conditioning_scale=controls_conditioning_scale,
-                        condition_mask=cond_mask,
-                        padding_mask=padding_mask,
-                        # TODO: before or after projection?
-                        # encoder_hidden_states=encoder_hidden_states,  # before
-                        # TODO: pass as prepared_inputs dict?
-                        encoder_hidden_states=prepared_inputs["encoder_hidden_states"],  # after
-                        temb=prepared_inputs["temb"],
-                        embedded_timestep=prepared_inputs["embedded_timestep"],
-                        image_rotary_emb=prepared_inputs["image_rotary_emb"],
-                        extra_pos_emb=prepared_inputs["extra_pos_emb"],
-                        attention_mask=prepared_inputs["attention_mask"],
-                    )
-
-                # breakpoint()
-                noise_pred = self.transformer._forward(
-                    prepared_inputs=prepared_inputs,
-                    block_controlnet_hidden_states=control_blocks,
-                    return_dict=False,
-                )[0]
                 # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
                 noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
 
                 if self.do_classifier_free_guidance:
-                    noise_pred_neg = self.transformer._forward(
-                        prepared_inputs=prepared_inputs,
-                        block_controlnet_hidden_states=control_blocks,
-                        return_dict=False,
-                    )[0]
+                    noise_pred_neg = transfer2_5_forward(
+                        transformer=self.transformer,
+                        controlnet=self.controlnet,
+                        in_latents=in_latents,
+                        controls_latents=controls_latents,
+                        controls_conditioning_scale=controls_conditioning_scale,
+                        in_timestep=in_timestep,
+                        encoder_hidden_states=neg_encoder_hidden_states,
+                        cond_mask=cond_mask,
+                        padding_mask=padding_mask,
+                    )
                     # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
                     noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
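+                    # NOTE: guidance here extrapolates away from the negative prediction,
+                    # i.e. pred + s * (pred - pred_neg) rather than neg + s * (pos - neg)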
                     noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
@@ -902,10 +938,7 @@ def __call__(
                 # vid = self.safety_checker.check_video_safety(vid)
                 video_batch.append(vid)
             video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
-            try:
-                video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
-            except:
-                breakpoint()
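+            # stacked frames are (B, F, H, W, C); permute to (B, C, F, H, W) for postprocess_video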
+            video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
             video = latents