Commit df1eacb

debugging
1 parent 93f2f40 commit df1eacb

File tree

4 files changed, +120 -72 lines changed

src/diffusers/models/controlnets/controlnet_cosmos.py

Lines changed: 1 addition & 3 deletions
@@ -135,9 +135,6 @@ def forward(
         scales = self._expand_conditioning_scale(conditioning_scale)
         result = []
         for block_idx, (block, scale) in enumerate(zip(self.control_blocks, scales)):
-            # print(block_idx, "scale=", scale)
-            # print("control_hidden_states.shape=", control_hidden_states.shape)
-            # breakpoint()
             control_hidden_states = block(
                 hidden_states=control_hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
@@ -147,6 +144,7 @@ def forward(
                 extra_pos_emb=extra_pos_emb,
                 attention_mask=attention_mask,
                 controlnet_residual=None,
+                block_idx=block_idx,
             )
             result.append(control_hidden_states * scale)
         return result
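
The loop above is the standard ControlNet pattern: each control block emits a residual that is scaled and later injected into the matching transformer block. Below is a minimal, self-contained sketch of that pattern; the class and variable names are illustrative only, not the actual diffusers implementation.

import torch
import torch.nn as nn


class TinyControlBranch(nn.Module):
    def __init__(self, hidden_size: int, n_blocks: int):
        super().__init__()
        self.control_blocks = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(n_blocks)
        )
        # Zero-init ("zero conv") so the branch contributes nothing before training.
        for block in self.control_blocks:
            nn.init.zeros_(block.weight)
            nn.init.zeros_(block.bias)

    def forward(self, control_hidden_states: torch.Tensor, scales: list) -> list:
        result = []
        for block, scale in zip(self.control_blocks, scales):
            control_hidden_states = block(control_hidden_states)
            # One scaled residual per entry; each feeds one transformer block.
            result.append(control_hidden_states * scale)
        return result


branch = TinyControlBranch(hidden_size=8, n_blocks=3)
residuals = branch(torch.randn(2, 8), scales=[1.0, 1.0, 1.0])
assert all(torch.equal(r, torch.zeros(2, 8)) for r in residuals)

The zero-initialization is what the "zero conv for CosmosControlNet" NOTE in the transformer diff below refers to: the control branch starts as a no-op and only begins steering the backbone as its weights move during training.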

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 34 additions & 18 deletions
@@ -395,6 +395,8 @@ def __init__(
         self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu", bias=out_bias)

         # NOTE: zero conv for CosmosControlNet
+        self.before_proj = None
+        self.after_proj = None
         if before_proj:
             # TODO: check hint_dim in i4
             self.before_proj = nn.Linear(hidden_size, hidden_size)
@@ -411,7 +413,12 @@ def forward(
         extra_pos_emb: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         controlnet_residual: Optional[torch.Tensor] = None,
+        block_idx: Optional[int] = None,
     ) -> torch.Tensor:
+        if self.before_proj is not None:
+            hidden_states = self.before_proj(hidden_states)
+            print(f"before_proj, block_idx={block_idx}")
+
         if extra_pos_emb is not None:
             hidden_states = hidden_states + extra_pos_emb

@@ -434,8 +441,13 @@ def forward(

         if controlnet_residual is not None:
             # NOTE: this is assumed to be scaled by the controlnet
+            # print("controlnet_residual")
             hidden_states += controlnet_residual

+        if self.after_proj is not None:
+            hidden_states = self.after_proj(hidden_states)
+            print(f"after_proj, block_idx={block_idx}")
+
         return hidden_states


@@ -745,11 +757,10 @@ def prepare_inputs(
         else:
             assert False

-        text_context, img_context = encoder_hidden_states
+        text_context, img_context = encoder_hidden_states if isinstance(encoder_hidden_states, tuple) else (encoder_hidden_states, None)
         if self.config.use_crossattn_projection:
             text_context = self.crossattn_proj(text_context)

-        # TODO: project img_context
         if img_context is not None and self.config.img_context_dim:
             img_context = self.img_context_proj(img_context)

@@ -760,7 +771,8 @@ def prepare_inputs(
             "image_rotary_emb": image_rotary_emb,
             "extra_pos_emb": extra_pos_emb,
             "attention_mask": attention_mask,
-            "encoder_hidden_states": (text_context, img_context),
+            # TODO: improve
+            "encoder_hidden_states": (text_context, img_context) if isinstance(encoder_hidden_states, tuple) else text_context,
             "num_frames": num_frames,
             "post_patch_num_frames": post_patch_num_frames,
             "post_patch_height": post_patch_height,
@@ -780,22 +792,24 @@ def forward(
         padding_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> torch.Tensor:
-        if prepared_inputs is None:
-            prepared_inputs = self.prepare_inputs(
-                hidden_states=hidden_states,
-                timestep=timestep,
-                encoder_hidden_states=encoder_hidden_states,
-                block_controlnet_hidden_states=block_controlnet_hidden_states,
-                attention_mask=attention_mask,
-                fps=fps,
-                condition_mask=condition_mask,
-                padding_mask=padding_mask,
-                return_dict=return_dict,
-            )
-        return self._forward(prepared_inputs, block_controlnet_hidden_states=block_controlnet_hidden_states, return_dict=return_dict)
+        prepared_inputs = self.prepare_inputs(
+            hidden_states=hidden_states,
+            timestep=timestep,
+            encoder_hidden_states=encoder_hidden_states,
+            block_controlnet_hidden_states=block_controlnet_hidden_states,
+            attention_mask=attention_mask,
+            fps=fps,
+            condition_mask=condition_mask,
+            padding_mask=padding_mask,
+        )
+
+        return self._forward(
+            prepared_inputs,
+            block_controlnet_hidden_states=block_controlnet_hidden_states,
+            return_dict=return_dict,
+        )

-    def _forward(self, prepared_inputs: Optional[Dict[str, Any]] = None, block_controlnet_hidden_states: Optional[List[torch.Tensor]] = None, return_dict: bool = True) -> torch.Tensor:
-        # NOTE: in i4 controlnet_blocks are now computed ...
+    def _forward(self, prepared_inputs: Dict[str, Any], block_controlnet_hidden_states: Optional[List[torch.Tensor]] = None, return_dict: bool = True) -> torch.Tensor:
         controlnet_block_index_map = {}
         if block_controlnet_hidden_states is not None:
             n_blocks = len(self.transformer_blocks)
@@ -812,6 +826,8 @@ def _forward(self, prepared_inputs: Optional[Dict[str, Any]] = None, block_contr
         # 5. Transformer blocks
         for block_idx, block in enumerate(self.transformer_blocks):
             controlnet_residual = controlnet_block_index_map.get(block_idx)
+            if controlnet_residual is not None:
+                print("*", block_idx, "controlnet_residual")
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = self._gradient_checkpointing_func(
                     block,
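
The refactor above splits the transformer into prepare_inputs (compute the shared embeddings once) and _forward (run the blocks), so the same prepared inputs can feed both the controlnet call and the main pass. A toy, self-contained sketch of that split follows; it assumes nothing about the real model beyond the calling convention visible in the diff, and every name in it is illustrative.

from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn


class SplitForwardToy(nn.Module):
    """Toy mirror of the prepare_inputs/_forward split; not the real model."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(2))

    def prepare_inputs(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> Dict[str, Any]:
        # Shared, potentially expensive preprocessing happens exactly once.
        return {"hidden_states": hidden_states + temb}

    def _forward(
        self,
        prepared_inputs: Dict[str, Any],
        block_controlnet_hidden_states: Optional[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        hidden_states = prepared_inputs["hidden_states"]
        for block_idx, block in enumerate(self.blocks):
            hidden_states = block(hidden_states)
            if block_controlnet_hidden_states is not None:
                # Residual injection, like controlnet_block_index_map above.
                hidden_states = hidden_states + block_controlnet_hidden_states[block_idx]
        return hidden_states

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        prepared_inputs = self.prepare_inputs(hidden_states, temb)
        return self._forward(prepared_inputs)


model = SplitForwardToy()
x, t = torch.randn(2, 8), torch.randn(2, 8)
prepared = model.prepare_inputs(x, t)
# The same prepared inputs can drive a plain pass and a controlnet-conditioned
# pass without recomputing the embeddings.
out_plain = model._forward(prepared)
out_ctrl = model._forward(prepared, [torch.zeros(2, 8), torch.zeros(2, 8)])
assert torch.allclose(out_plain, out_ctrl)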

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py

Lines changed: 2 additions & 1 deletion
@@ -460,7 +460,8 @@ def prepare_latents(
         num_cond_latent_frames = (num_frames_in - 1) // self.vae_scale_factor_temporal + 1
         cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
         cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0
-        cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
+        # cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
+        cond_mask = zeros_padding  # TODO removeme
 
         return (
             latents,
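
This change effectively disables frame conditioning: with cond_mask all zeros, the downstream blend in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents reduces to the noisy latents alone. A reduced-shape illustration of that blend (real latents are 5-D B C T H W, and the ones_padding/zeros_padding tensors are elided here for brevity):

import torch

# cond_mask is 1 on conditioning frames, 0 on frames to generate.
latents = torch.randn(1, 1, 4, 1, 1)      # noisy latents, 4 latent frames
cond_latent = torch.zeros_like(latents)   # clean conditioning latents

cond_indicator = latents.new_zeros(1, 1, 4, 1, 1)
cond_indicator[:, :, 0:2] = 1.0           # first 2 frames as conditioning

# Original behavior: conditioning frames are taken from cond_latent.
in_latents = cond_indicator * cond_latent + (1 - cond_indicator) * latents
assert torch.equal(in_latents[:, :, 0:2], cond_latent[:, :, 0:2])

# This commit: cond_mask = zeros, so every frame stays a generated frame.
cond_mask = torch.zeros_like(cond_indicator)
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
assert torch.equal(in_latents, latents)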

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py

Lines changed: 83 additions & 50 deletions
@@ -74,6 +74,48 @@ def retrieve_latents(
     else:
         raise AttributeError("Could not access latents of provided encoder_output")

+# TODO: move this to a utility module aka Transfer2_5 model ?
+def transfer2_5_forward(
+    transformer,
+    controlnet,
+    in_latents,
+    controls_latents,
+    controls_conditioning_scale,
+    in_timestep,
+    encoder_hidden_states,
+    cond_mask,
+    padding_mask,
+):
+    control_blocks = None
+    prepared_inputs = transformer.prepare_inputs(
+        hidden_states=in_latents,
+        condition_mask=cond_mask,
+        timestep=in_timestep,
+        encoder_hidden_states=encoder_hidden_states,
+        padding_mask=padding_mask,
+    )
+    if controls_latents is not None:
+        control_blocks = controlnet(
+            controls_latents=controls_latents,
+            latents=in_latents,
+            conditioning_scale=controls_conditioning_scale,
+            condition_mask=cond_mask,
+            padding_mask=padding_mask,
+            encoder_hidden_states=prepared_inputs["encoder_hidden_states"],
+            temb=prepared_inputs["temb"],
+            embedded_timestep=prepared_inputs["embedded_timestep"],
+            image_rotary_emb=prepared_inputs["image_rotary_emb"],
+            extra_pos_emb=prepared_inputs["extra_pos_emb"],
+            attention_mask=prepared_inputs["attention_mask"],
+        )
+
+    noise_pred = transformer._forward(
+        prepared_inputs=prepared_inputs,
+        block_controlnet_hidden_states=control_blocks,
+        return_dict=False,
+    )[0]
+    return noise_pred
+

 EXAMPLE_DOC_STRING = """
     Examples:
@@ -227,7 +269,6 @@ def __init__(

         self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
-        # breakpoint()
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

         latents_mean = (
@@ -470,8 +511,10 @@ def prepare_latents(

         num_cond_latent_frames = (num_frames_in - 1) // self.vae_scale_factor_temporal + 1
         cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
-        cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0
-        cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
+        # cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0
+        # TODO: modify cond_mask per chunk
+        # cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
+        cond_mask = zeros_padding  # TODO this is what i4 uses

         return (
             latents,
@@ -569,7 +612,8 @@ def __call__(
         width: Optional[int] = None,
         num_frames: int = 93,
         num_inference_steps: int = 36,
-        guidance_scale: float = 7.0,
+        # guidance_scale: float = 7.0,  # TODO: check default
+        guidance_scale: float = 3.0,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -676,16 +720,21 @@ def __call__(

         if width is None:
             frame = image or video[0] if image or video else None
+            if frame is None and controls is not None:
+                frame = controls[0] if isinstance(controls, list) else controls
+                if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4:
+                    frame = controls[0]
+
             if frame is None:
-                width = (height + 16) * (1280/720)
+                width = int((height + 16) * (1280/720))
             elif isinstance(frame, PIL.Image.Image):
                 width = int((height + 16) * (frame.width / frame.height))
             else:
                 width = int((height + 16) * (frame.shape[2] / frame.shape[1]))  # NOTE: assuming C H W

         # Check inputs. Raise error if not correct
         print("width=", width, "height=", height)
-        breakpoint()
+        # breakpoint()
         self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)

         self._guidance_scale = guidance_scale
@@ -729,6 +778,7 @@ def __call__(
         )
         # TODO(migmartin): add img ref to prompt_embeds via siglip if provided
         encoder_hidden_states = (prompt_embeds, None)
+        neg_encoder_hidden_states = (negative_prompt_embeds, None)

         vae_dtype = self.vae.dtype
         transformer_dtype = self.transformer.dtype
@@ -815,51 +865,37 @@ def __call__(

                 in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
                 in_latents = in_latents.to(transformer_dtype)
-                in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
-                control_blocks = None
-
-                prepared_inputs = self.transformer.prepare_inputs(
-                    hidden_states=in_latents,
-                    condition_mask=cond_mask,
-                    timestep=in_timestep,
+                # in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
+                # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
+                in_latents = (0.5 * torch.ones((1, 16, 24, 88, 120))).cuda().to(dtype=transformer_dtype)
+                in_timestep = (torch.ones((1, 1, 24, 1, 1)) * 0.966).cuda().to(dtype=transformer_dtype)
+                breakpoint()
+                noise_pred = transfer2_5_forward(
+                    transformer=self.transformer,
+                    controlnet=self.controlnet,
+                    in_latents=in_latents,
+                    controls_latents=controls_latents,
+                    controls_conditioning_scale=controls_conditioning_scale,
+                    in_timestep=in_timestep,
                     encoder_hidden_states=encoder_hidden_states,
-                    padding_mask=padding_mask,
+                    cond_mask=cond_mask,
+                    padding_mask=padding_mask
                 )
-                # import IPython; IPython.embed()
-                # breakpoint()
-                if controls is not None:
-                    control_blocks = self.controlnet(
-                        controls_latents=controls_latents,
-                        latents=in_latents,
-                        conditioning_scale=controls_conditioning_scale,
-                        condition_mask=cond_mask,
-                        padding_mask=padding_mask,
-                        # TODO: before or after projection?
-                        # encoder_hidden_states=encoder_hidden_states, # before
-                        # TODO: pass as prepared_inputs dict ?
-                        encoder_hidden_states=prepared_inputs["encoder_hidden_states"], # after
-                        temb=prepared_inputs["temb"],
-                        embedded_timestep=prepared_inputs["embedded_timestep"],
-                        image_rotary_emb=prepared_inputs["image_rotary_emb"],
-                        extra_pos_emb=prepared_inputs["extra_pos_emb"],
-                        attention_mask=prepared_inputs["attention_mask"],
-                    )
-
-                # breakpoint()
-                noise_pred = self.transformer._forward(
-                    prepared_inputs=prepared_inputs,
-                    block_controlnet_hidden_states=control_blocks,
-                    return_dict=False,
-                )[0]
-                # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
                 noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
+                breakpoint()

                 if self.do_classifier_free_guidance:
-                    noise_pred_neg = self.transformer._forward(
-                        prepared_inputs=prepared_inputs,
-                        block_controlnet_hidden_states=control_blocks,
-                        return_dict=False,
-                    )[0]
+                    noise_pred_neg = transfer2_5_forward(
+                        transformer=self.transformer,
+                        controlnet=self.controlnet,
+                        in_latents=in_latents,
+                        controls_latents=controls_latents,
+                        controls_conditioning_scale=controls_conditioning_scale,
+                        in_timestep=in_timestep,
+                        encoder_hidden_states=neg_encoder_hidden_states,
+                        cond_mask=cond_mask,
+                        padding_mask=padding_mask
+                    )
                     # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
                     noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
                     noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
@@ -902,10 +938,7 @@ def __call__(
                 # vid = self.safety_checker.check_video_safety(vid)
                 video_batch.append(vid)
             video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
-            try:
-                video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
-            except:
-                breakpoint()
+            video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
            video = latents
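
Both CFG branches now go through transfer2_5_forward, and each raw prediction is overwritten with gt_velocity on conditioning frames before the guidance combine. The toy numbers below illustrate that combine, using the formula exactly as it appears in the diff (pred + s * (pred - pred_neg)); all values are made up for the example.

import torch

guidance_scale = 3.0
cond_mask = torch.tensor([1.0, 0.0])      # frame 0 conditioned, frame 1 generated
gt_velocity = torch.tensor([0.5, 0.0])    # ground-truth velocity, zero where unused

noise_pred = torch.tensor([9.9, 1.0])     # raw positive prediction
noise_pred_neg = torch.tensor([9.9, 0.4]) # raw negative prediction

# Conditioning frames keep gt_velocity; generated frames keep the model output.
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)

# CFG combine as written in the diff: pred + s * (pred - pred_neg).
guided = noise_pred + guidance_scale * (noise_pred - noise_pred_neg)
print(guided)  # tensor([0.5000, 2.8000])

Because the positive and negative predictions agree exactly on conditioning frames after the gt_velocity overwrite, the guidance term vanishes there and only generated frames are steered.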
