Skip to content

Commit f928f3f

Browse files
cleanup + detail on neg_encoder_hidden_states
1 parent 73ed1bc commit f928f3f

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -241,12 +241,11 @@ def __init__(self):
241241

242242
def compute_attn_i2v(
243243
self,
244-
attn: Attention, # TODO: CosmosAttention
244+
attn: Attention,
245245
hidden_states: torch.Tensor,
246246
img_context=None,
247247
attention_mask=None,
248248
):
249-
print("compute_attn_i2v", flush=True)
250249
q_img = attn.q_img(hidden_states)
251250
k_img = attn.k_img(img_context)
252251
v_img = attn.v_img(img_context)
@@ -294,10 +293,7 @@ def __call__(
294293
image_rotary_emb=image_rotary_emb,
295294
)
296295

297-
# TODO: fixme
298-
# NOTE: img_context should be zeros
299296
if img_context is not None:
300-
print("compute_attn_i2v", flush=True)
301297
img_out = self.compute_attn_i2v(
302298
attn=attn,
303299
hidden_states=hidden_states,
@@ -422,7 +418,7 @@ def forward(
422418
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
423419
if self.before_proj is not None:
424420
hidden_states = self.before_proj(hidden_states) + latents
425-
print(f"before_proj, block_idx={block_idx}")
421+
# print(f"before_proj, block_idx={block_idx}")
426422

427423
if extra_pos_emb is not None:
428424
hidden_states = hidden_states + extra_pos_emb
@@ -444,17 +440,18 @@ def forward(
444440
ff_output = self.ff(norm_hidden_states)
445441
hidden_states = hidden_states + gate * ff_output
446442

443+
if controlnet_residual is not None:
444+
assert self.after_proj is None
445+
# NOTE: this is assumed to be scaled by the controlnet
446+
# print("controlnet_residual", flush=True)
447+
hidden_states += controlnet_residual
448+
447449
if self.after_proj is not None:
448450
assert controlnet_residual is None
449451
hs_proj = self.after_proj(hidden_states)
450-
print(f"after_proj, block_idx={block_idx}")
452+
# print(f"after_proj, block_idx={block_idx}")
451453
return hidden_states, hs_proj
452454

453-
if controlnet_residual is not None:
454-
# NOTE: this is assumed to be scaled by the controlnet
455-
print("controlnet_residual", flush=True)
456-
hidden_states += controlnet_residual
457-
458455
return hidden_states
459456

460457

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,6 @@ def __call__(
611611
width: Optional[int] = None,
612612
num_frames: int = 93,
613613
num_inference_steps: int = 36,
614-
# guidance_scale: float = 7.0, # TODO: check default
615614
guidance_scale: float = 3.0,
616615
num_videos_per_prompt: Optional[int] = 1,
617616
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -658,7 +657,7 @@ def __call__(
658657
num_inference_steps (`int`, defaults to `36`):
659658
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
660659
expense of slower inference.
661-
guidance_scale (`float`, defaults to `7.0`):
660+
guidance_scale (`float`, defaults to `3.0`):
662661
Guidance scale as defined in [Classifier-Free Diffusion
663662
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
664663
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
@@ -773,13 +772,16 @@ def __call__(
773772
device=device,
774773
max_sequence_length=max_sequence_length,
775774
)
776-
# TODO(migmartin): add img ref to prompt_embeds via siglip if provided
777-
encoder_hidden_states = (prompt_embeds, None)
778-
neg_encoder_hidden_states = (negative_prompt_embeds, None)
779775

780776
vae_dtype = self.vae.dtype
781777
transformer_dtype = self.transformer.dtype
782778

779+
# TODO(migmartin): add img ref to prompt_embeds via siglip if image ref is provided
780+
img_context_ref = torch.zeros(1, 256, 1152).to(device=prompt_embeds.device, dtype=transformer_dtype)
781+
encoder_hidden_states = (prompt_embeds, img_context_ref)
782+
# NOTE: projects/cosmos/transfer2/configs/vid2vid_transfer/defaults/conditioner.py L240
783+
neg_encoder_hidden_states = (negative_prompt_embeds, None)
784+
783785
num_frames_in = None
784786
if image is not None:
785787
if batch_size != 1:

0 commit comments

Comments
 (0)