@@ -199,7 +199,7 @@ def __init__(
199199 transformer : CosmosTransformer3DModel ,
200200 vae : AutoencoderKLWan ,
201201 scheduler : UniPCMultistepScheduler ,
202- controlnet : Optional [ CosmosControlNetModel ] = None ,
202+ controlnet : CosmosControlNetModel ,
203203 safety_checker : CosmosSafetyChecker = None ,
204204 ):
205205 super ().__init__ ()
@@ -474,23 +474,25 @@ def prepare_latents(
474474 cond_indicator ,
475475 )
476476
477- def _encode_controlnet_image (
477+ def _encode_controls (
478478 self ,
479- control_image : Optional [torch .Tensor ],
479+ controls : Optional [torch .Tensor ],
480480 height : int ,
481481 width : int ,
482482 num_frames : int ,
483483 dtype : torch .dtype ,
484484 device : torch .device ,
485485 ) -> Optional [torch .Tensor ]:
486- if control_image is None :
486+ if controls is None :
487487 return None
488488
489- control_video = self .video_processor .preprocess_video (control_image , height , width )
490- if control_video .shape [2 ] < num_frames :
491- n_pad_frames = num_frames - control_video .shape [2 ]
492- last_frame = control_video [:, :, - 1 :, :, :]
493- control_video = torch .cat ((control_video , last_frame .repeat (1 , 1 , n_pad_frames , 1 , 1 )), dim = 2 )
489+ # TODO(review): preprocess_video assumes video-shaped input; a single control image may need its own preprocessing path — confirm against callers before release.
490+ control_video = self .video_processor .preprocess_video (controls , height , width )
491+ # TODO(review): the disabled padding below repeated the last frame so control videos shorter than num_frames still matched — verify short-video inputs still work before deleting it for good.
492+ # if control_video.shape[2] < num_frames:
493+ # n_pad_frames = num_frames - control_video.shape[2]
494+ # last_frame = control_video[:, :, -1:, :, :]
495+ # control_video = torch.cat((control_video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
494496
495497 control_video = control_video .to (device = device , dtype = self .vae .dtype )
496498 control_latents = [retrieve_latents (self .vae .encode (vid .unsqueeze (0 ))) for vid in control_video ]
@@ -568,8 +570,8 @@ def __call__(
568570 num_videos_per_prompt : Optional [int ] = 1 ,
569571 generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
570572 latents : Optional [torch .Tensor ] = None ,
571- controlnet_conditioning_scale : Union [ float , List [float ]] = 1.0 ,
572- controlnet_conditioning_image : Optional [ PipelineImageInput ] = None ,
573+ controls : Optional [ PipelineImageInput | List [PipelineImageInput ]] = None ,
574+ controls_conditioning_scale : Union [ float , List [ float ]] = 1.0 ,
573575 prompt_embeds : Optional [torch .Tensor ] = None ,
574576 negative_prompt_embeds : Optional [torch .Tensor ] = None ,
575577 output_type : Optional [str ] = "pil" ,
@@ -623,10 +625,10 @@ def __call__(
623625 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
624626 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
625627 tensor is generated by sampling using the supplied random `generator`.
626- controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
627- The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
628- controlnet_conditioning_image (`PipelineImageInput`, *optional*):
628+ controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
629629 Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
630+ controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
631+ The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
630632 prompt_embeds (`torch.Tensor`, *optional*):
631633 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
632634 provided, text embeddings will be generated from `prompt` input argument.
@@ -765,19 +767,20 @@ def __call__(
765767 cond_timestep = torch .ones_like (cond_indicator ) * conditional_frame_timestep
766768 cond_mask = cond_mask .to (transformer_dtype )
767769
768- controlnet_latents = None
769- if self . controlnet is not None and controlnet_conditioning_image is not None :
770- controlnet_latents = self ._encode_controlnet_image (
771- control_image = controlnet_conditioning_image ,
770+ controls_latents = None
771+ if controls is not None :
772+ controls_latents = self ._encode_controls (
773+ controls ,
772774 height = height ,
773775 width = width ,
774776 num_frames = num_frames ,
775777 dtype = torch .float32 ,
776778 device = device ,
777779 )
778- if controlnet_latents .shape [0 ] != latents .shape [0 ]:
779- repeat_count = latents .shape [0 ] // controlnet_latents .shape [0 ]
780- controlnet_latents = controlnet_latents .repeat_interleave (repeat_count , dim = 0 )
780+ # TODO(review): the disabled broadcast below repeated control latents along the batch dim to match latents (num_videos_per_prompt > 1) — verify that multi-video generation still works before deleting.
781+ # if controls_latents.shape[0] != latents.shape[0]:
782+ # repeat_count = latents.shape[0] // controls_latents.shape[0]
783+ # controls_latents = controls_latents.repeat_interleave(repeat_count, dim=0)
781784
782785 padding_mask = latents .new_zeros (1 , 1 , height , width , dtype = transformer_dtype )
783786
@@ -805,24 +808,24 @@ def __call__(
805808 in_latents = cond_mask * cond_latent + (1 - cond_mask ) * latents
806809 in_latents = in_latents .to (transformer_dtype )
807810 in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator ) * sigma_t
808- control_block_samples = None
809- if self . controlnet is not None and controlnet_latents is not None :
810- control_block_samples = self .controlnet (
811+ control_blocks = None
812+ if controls is not None :
813+ control_blocks = self .controlnet (
811814 hidden_states = in_latents ,
812- controlnet_cond = controlnet_latents .to (dtype = transformer_dtype ),
815+ controlnet_cond = controls_latents .to (dtype = transformer_dtype ),
813816 timestep = in_timestep ,
814817 encoder_hidden_states = prompt_embeds ,
815- conditioning_scale = controlnet_conditioning_scale ,
818+ conditioning_scale = controls_conditioning_scale ,
816819 return_dict = True ,
817- ). block_controlnet_hidden_states
818- control_block_samples = tuple ( residual . to ( dtype = transformer_dtype ) for residual in control_block_samples )
820+ )
821+
819822 noise_pred = self .transformer (
820823 hidden_states = in_latents ,
821824 condition_mask = cond_mask ,
822825 timestep = in_timestep ,
823826 encoder_hidden_states = prompt_embeds ,
824827 padding_mask = padding_mask ,
825- block_controlnet_hidden_states = control_block_samples ,
828+ block_controlnet_hidden_states = control_blocks ,
826829 return_dict = False ,
827830 )[0 ]
828831 # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
@@ -835,7 +838,7 @@ def __call__(
835838 timestep = in_timestep ,
836839 encoder_hidden_states = negative_prompt_embeds ,
837840 padding_mask = padding_mask ,
838- block_controlnet_hidden_states = control_block_samples ,
841+ block_controlnet_hidden_states = control_blocks ,
839842 return_dict = False ,
840843 )[0 ]
841844 # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
@@ -868,7 +871,8 @@ def __call__(
868871 latents_std = self .latents_std .to (latents .device , latents .dtype )
869872 latents = latents * latents_std + latents_mean
870873 video = self .vae .decode (latents .to (self .vae .dtype ), return_dict = False )[0 ]
871- video = self ._match_num_frames (video , num_frames )
874+ # TODO(review): _match_num_frames reconciled the decoded frame count with the requested num_frames; confirm the VAE decode always yields exactly num_frames before removing this call.
875+ # video = self._match_num_frames(video, num_frames)
872876
873877 assert self .safety_checker is not None
874878 self .safety_checker .to (device )
@@ -892,6 +896,7 @@ def __call__(
892896
893897 return CosmosPipelineOutput (frames = video )
894898
899+ # TODO(review): document why the decoded frame count can differ from num_frames (temporal compression rounding in the VAE?) or delete this helper once the mismatch is confirmed impossible.
895900 def _match_num_frames (self , video : torch .Tensor , target_num_frames : int ) -> torch .Tensor :
896901 if target_num_frames <= 0 or video .shape [2 ] == target_num_frames :
897902 return video
0 commit comments