Skip to content

Commit 73ed1bc

Browse files
control working
1 parent f53df18 commit 73ed1bc

File tree

3 files changed

+16
-16
lines changed

3 files changed

+16
-16
lines changed

src/diffusers/models/controlnets/controlnet_cosmos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def forward(
135135
scales = self._expand_conditioning_scale(conditioning_scale)
136136
result = []
137137
for block_idx, (block, scale) in enumerate(zip(self.control_blocks, scales)):
138-
control_hidden_states = block(
138+
control_hidden_states, control_proj = block(
139139
hidden_states=control_hidden_states,
140140
encoder_hidden_states=encoder_hidden_states,
141141
embedded_timestep=embedded_timestep,
@@ -147,5 +147,5 @@ def forward(
147147
block_idx=block_idx,
148148
latents=latents,
149149
)
150-
result.append(control_hidden_states * scale)
150+
result.append(control_proj * scale)
151151
return result

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def forward(
419419
controlnet_residual: Optional[torch.Tensor] = None,
420420
latents: Optional[torch.Tensor] = None,
421421
block_idx: Optional[int] = None,
422-
) -> torch.Tensor:
422+
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
423423
if self.before_proj is not None:
424424
hidden_states = self.before_proj(hidden_states) + latents
425425
print(f"before_proj, block_idx={block_idx}")
@@ -445,8 +445,10 @@ def forward(
445445
hidden_states = hidden_states + gate * ff_output
446446

447447
if self.after_proj is not None:
448-
hidden_states = self.after_proj(hidden_states)
448+
assert controlnet_residual is None
449+
hs_proj = self.after_proj(hidden_states)
449450
print(f"after_proj, block_idx={block_idx}")
451+
return hidden_states, hs_proj
450452

451453
if controlnet_residual is not None:
452454
# NOTE: this is assumed to be scaled by the controlnet
@@ -846,7 +848,6 @@ def _forward(self, prepared_inputs: Dict[str, Any], block_controlnet_hidden_stat
846848
prepared_inputs["extra_pos_emb"],
847849
prepared_inputs["attention_mask"],
848850
controlnet_residual,
849-
latents,
850851
)
851852
else:
852853
hidden_states = block(
@@ -858,7 +859,6 @@ def _forward(self, prepared_inputs: Dict[str, Any], block_controlnet_hidden_stat
858859
prepared_inputs["extra_pos_emb"],
859860
prepared_inputs["attention_mask"],
860861
controlnet_residual,
861-
latents,
862862
)
863863

864864
temb = prepared_inputs["temb"]

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,15 @@ def retrieve_latents(
7676

7777
# TODO: move this to a utility module aka Transfer2_5 model ?
7878
def transfer2_5_forward(
79-
transformer,
80-
controlnet,
81-
in_latents,
82-
controls_latents,
83-
controls_conditioning_scale,
84-
in_timestep,
85-
encoder_hidden_states,
86-
cond_mask,
87-
padding_mask,
79+
transformer: CosmosTransformer3DModel,
80+
controlnet: CosmosControlNetModel,
81+
in_latents: torch.Tensor,
82+
controls_latents: torch.Tensor,
83+
controls_conditioning_scale: list[float],
84+
in_timestep: torch.Tensor,
85+
encoder_hidden_states: tuple[torch.Tensor | None, torch.Tensor | None] | None,
86+
cond_mask: torch.Tensor,
87+
padding_mask: torch.Tensor,
8888
):
8989
control_blocks = None
9090
prepared_inputs = transformer.prepare_inputs(
@@ -97,7 +97,7 @@ def transfer2_5_forward(
9797
if controls_latents is not None:
9898
control_blocks = controlnet(
9999
controls_latents=controls_latents,
100-
latents=in_latents,
100+
latents=prepared_inputs["hidden_states"],
101101
conditioning_scale=controls_conditioning_scale,
102102
condition_mask=cond_mask,
103103
padding_mask=padding_mask,

0 commit comments

Comments (0)