
Commit 5705dfa

convert controlnet
1 parent eee28ce commit 5705dfa

File tree

3 files changed: +64, -63 lines

scripts/convert_cosmos_to_diffusers.py

Lines changed: 13 additions & 30 deletions
@@ -393,50 +393,35 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
         "use_crossattn_projection": True,
         "crossattn_proj_in_channels": 100352,
         "encoder_hidden_states_channels": 1024,
-        "n_control_net_blocks": 4,
         "controlnet_block_every_n": 7,
         "img_context_dim": 1152,
     },
 }
 
 CONTROLNET_CONFIGS = {
     "Cosmos-2.5-Transfer-General-2B": {
-        "in_channels": 16 + 1,
+        "n_controlnet_blocks": 4,
+        "model_channels": 2048,
+        "in_channels": 130,
         "num_attention_heads": 16,
         "attention_head_dim": 128,
-        "num_layers": 4,
+        "mlp_ratio": 4.0,
+        "text_embed_dim": 1024,
+        "adaln_lora_dim": 256,
         "patch_size": (1, 2, 2),
-        "control_block_indices": (6, 13, 20, 27),
     },
 }
 
 # TODO(migmartin): fix this, this is not correct
 CONTROLNET_KEYS_RENAME_DICT = {
-    "controlnet_blocks": "control_blocks",
-    "control_net_blocks": "control_blocks",
-    "control_blocks.block": "control_blocks.",
-    "control_blocks": "control_blocks",
-    ".linear": ".proj",
-    ".proj.0": ".proj",
-    ".proj.1": ".proj",
-    "x_embedder_control": "patch_embed",
-    "control_patch_embed": "patch_embed",
-    "controlnet_patch_embed": "patch_embed",
-    "control_embedder": "patch_embed",
+    **TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0,
+    "blocks": "blocks",
+    "control_embedder.proj.1": "patch_embed.proj",
 }
 
 
-def rename_controlnet_blocks_(key: str, state_dict: Dict[str, Any]):
-    block_index = int(key.split(".")[1].removeprefix("block"))
-    new_key = key
-    old_prefix = f"control_blocks.block{block_index}"
-    new_prefix = f"control_blocks.{block_index}"
-    new_key = new_prefix + new_key.removeprefix(old_prefix)
-    state_dict[new_key] = state_dict.pop(key)
-
-
 CONTROLNET_SPECIAL_KEYS_REMAP = {
-    "control_blocks.block": rename_controlnet_blocks_,
+    **TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
 }
 
 VAE_KEYS_RENAME_DICT = {
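
A minimal sketch (not part of the commit) of how such rename dicts are applied in these conversion scripts: plain substring replacement over every checkpoint key, as in the loop shown in the next hunk. Only the two explicit entries added above are used here; the entries pulled in from TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 are omitted, and the example key is hypothetical.

    # Illustrative only: substring-based key renaming, mirroring the dict above.
    EXAMPLE_RENAME_DICT = {
        "blocks": "blocks",
        "control_embedder.proj.1": "patch_embed.proj",
    }

    def rename_key(key: str) -> str:
        for old, new in EXAMPLE_RENAME_DICT.items():
            key = key.replace(old, new)
        return key

    # Hypothetical original checkpoint key:
    print(rename_key("control_embedder.proj.1.weight"))  # -> patch_embed.proj.weight
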
@@ -606,8 +591,6 @@ def convert_controlnet(transformer_type: str, state_dict: Dict[str, Any], weight
         new2old[new_key] = key
         update_state_dict_(state_dict, key, new_key)
 
-    breakpoint()
-
     for key in list(state_dict.keys()):
         for special_key, handler_fn_inplace in CONTROLNET_SPECIAL_KEYS_REMAP.items():
             if special_key not in key:
@@ -832,12 +815,12 @@ def get_args():
         base_state_dict[k] = v
     assert len(base_state_dict.keys() & control_state_dict.keys()) == 0
 
-    transformer = convert_transformer(args.transformer_type, state_dict=base_state_dict, weights_only=weights_only)
-    transformer = transformer.to(dtype=dtype)
-
     controlnet = convert_controlnet(args.transformer_type, control_state_dict, weights_only=weights_only)
     controlnet = controlnet.to(dtype=dtype)
 
+    transformer = convert_transformer(args.transformer_type, state_dict=base_state_dict, weights_only=weights_only)
+    transformer = transformer.to(dtype=dtype)
+
     if not args.save_pipeline:
         transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
         controlnet.save_pretrained(
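
As an assumed usage sketch (not from this commit), the converted modules can be reloaded afterwards with the standard diffusers from_pretrained API; the paths below are placeholders, and the exact directory layout depends on the save_pretrained calls above.

    # Illustrative only; paths are hypothetical.
    import torch
    from diffusers.models.controlnets.controlnet_cosmos import CosmosControlNetModel
    from diffusers.models.transformers.transformer_cosmos import CosmosTransformer3DModel

    transformer = CosmosTransformer3DModel.from_pretrained("path/to/output", torch_dtype=torch.bfloat16)
    controlnet = CosmosControlNetModel.from_pretrained("path/to/output_controlnet", torch_dtype=torch.bfloat16)
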

src/diffusers/models/controlnets/controlnet_cosmos.py

Lines changed: 40 additions & 28 deletions
@@ -10,49 +10,52 @@
 from ..modeling_utils import ModelMixin
 from ..transformers.transformer_cosmos import (
     CosmosPatchEmbed,
+    CosmosTransformerBlock,
 )
 from .controlnet import zero_module
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class CosmosControlNetBlock(nn.Module):
-    def __init__(self, hidden_size: int):
-        super().__init__()
-        self.proj = zero_module(nn.Linear(hidden_size, hidden_size, bias=True))
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        return self.proj(hidden_states)
-
-
 # TODO(migmartin): implement me
 # see i4/projects/cosmos/transfer2/networks/minimal_v4_lvg_dit_control_vace.py
 class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
-    Minimal ControlNet for Cosmos Transfer2.5.
-
-    This module projects encoded control latents into per-block residuals aligned with the
-    `CosmosTransformer3DModel` hidden size. All projections are zero-initialized so the ControlNet
-    starts neutral by default.
+    ControlNet for Cosmos Transfer2.5.
     """
 
     @register_to_config
     def __init__(
         self,
+        n_controlnet_blocks: int = 4,
         in_channels: int = 16,
+        model_channels: int = 2048,
         num_attention_heads: int = 32,
         attention_head_dim: int = 128,
-        num_layers: int = 4,
+        mlp_ratio: float = 4.0,
+        text_embed_dim: int = 1024,
+        adaln_lora_dim: int = 256,
         patch_size: Tuple[int, int, int] = (1, 2, 2),
-        control_block_indices: Tuple[int, ...] = (6, 13, 20, 27),
     ):
         super().__init__()
-        hidden_size = num_attention_heads * attention_head_dim
-
-        self.patch_embed = CosmosPatchEmbed(in_channels, hidden_size, patch_size, bias=False)
+        self.patch_embed = CosmosPatchEmbed(in_channels, model_channels, patch_size, bias=False)
         self.control_blocks = nn.ModuleList(
-            CosmosControlNetBlock(hidden_size) for _ in range(num_layers)
+            [
+                CosmosTransformerBlock(
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    cross_attention_dim=text_embed_dim,
+                    mlp_ratio=mlp_ratio,
+                    adaln_lora_dim=adaln_lora_dim,
+                    qk_norm="rms_norm",
+                    out_bias=False,
+                    img_context=True,
+                    before_proj=(block_idx == 0),
+                    after_proj=True,
+                )
+                for block_idx in range(n_controlnet_blocks)
+            ]
         )
 
     def _expand_conditioning_scale(self, conditioning_scale: Union[float, List[float]]) -> List[float]:
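
A quick illustrative sketch (not part of the commit) of instantiating this module with the "Cosmos-2.5-Transfer-General-2B" values from CONTROLNET_CONFIGS in the conversion script above; note that num_attention_heads * attention_head_dim (16 * 128 = 2048) matches model_channels.

    from diffusers.models.controlnets.controlnet_cosmos import CosmosControlNetModel

    controlnet = CosmosControlNetModel(
        n_controlnet_blocks=4,
        in_channels=130,
        model_channels=2048,
        num_attention_heads=16,
        attention_head_dim=128,
        mlp_ratio=4.0,
        text_embed_dim=1024,
        adaln_lora_dim=256,
        patch_size=(1, 2, 2),
    )
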
@@ -61,7 +64,7 @@ def _expand_conditioning_scale(self, conditioning_scale: Union[float, List[float
         else:
             scales = [conditioning_scale] * len(self.control_blocks)
 
-        if len(scales) != len(self.control_blocks):
+        if len(scales) < len(self.control_blocks):
             logger.warning(
                 "Received %d control scales, but control network defines %d blocks. "
                 "Scales will be trimmed or repeated to match.",
@@ -75,16 +78,25 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         controlnet_cond: torch.Tensor,
-        timestep: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
         conditioning_scale: Union[float, List[float]] = 1.0,
-        return_dict: bool = True,
     ) -> List[torch.Tensor]:
-        del hidden_states, timestep, encoder_hidden_states  # not used in this minimal control path
-
         control_hidden_states = self.patch_embed(controlnet_cond)
         control_hidden_states = control_hidden_states.flatten(1, 3)
 
         scales = self._expand_conditioning_scale(conditioning_scale)
-        control_residuals = tuple(block(control_hidden_states) * scale for block, scale in zip(self.control_blocks, scales))
-        return control_residuals
+        x = hidden_states
+
+        # NOTE: args to block
+        # hidden_states: torch.Tensor,
+        # encoder_hidden_states: torch.Tensor,
+        # embedded_timestep: torch.Tensor,
+        # temb: Optional[torch.Tensor] = None,
+        # image_rotary_emb: Optional[torch.Tensor] = None,
+        # extra_pos_emb: Optional[torch.Tensor] = None,
+        # attention_mask: Optional[torch.Tensor] = None,
+        # controlnet_residual: Optional[torch.Tensor] = None,
+        result = []
+        for block, scale in zip(self.control_blocks, scales):
+            x = block(x)
+            result.append(x * scale)
+        return result
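
For intuition, the token-count arithmetic behind patch_embed + flatten(1, 3) above, assuming hypothetical latent dimensions and the patch_size=(1, 2, 2) / model_channels=2048 values from the config: the control latent becomes a token sequence with the same layout the transformer blocks operate on.

    # Illustrative only: (B, 130, T, H, W) -> (B, T * (H // 2) * (W // 2), 2048)
    T, H, W = 16, 88, 160          # hypothetical latent dimensions
    tokens = T * (H // 2) * (W // 2)
    print(tokens)                  # 56320
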

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 11 additions & 5 deletions
@@ -341,6 +341,8 @@ def __init__(
         qk_norm: str = "rms_norm",
         out_bias: bool = False,
         img_context: bool = False,
+        before_proj: bool = False,
+        after_proj: bool = False,
     ) -> None:
         super().__init__()
 
@@ -386,6 +388,13 @@ def __init__(
         self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
         self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu", bias=out_bias)
 
+        # NOTE: zero conv for CosmosControlNet
+        if before_proj:
+            # TODO: check hint_dim in i4
+            self.before_proj = nn.Linear(hidden_size, hidden_size)
+        if after_proj:
+            self.after_proj = nn.Linear(hidden_size, hidden_size)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
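
The NOTE above refers to these projections as zero convs, while the lines as committed use plain nn.Linear layers with default initialization. If they are meant to start neutral, a zero-initialized variant in the spirit of the zero_module helper imported in controlnet_cosmos.py might look like the following sketch; this is not part of the commit.

    import torch.nn as nn

    def zero_linear(hidden_size: int) -> nn.Linear:
        # Zero-initialized projection so the control path contributes nothing at start.
        proj = nn.Linear(hidden_size, hidden_size)
        nn.init.zeros_(proj.weight)
        nn.init.zeros_(proj.bias)
        return proj
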
@@ -418,7 +427,7 @@ def forward(
         hidden_states = hidden_states + gate * ff_output
 
         if controlnet_residual is not None:
-            # TODO: add control_context_scale ?
+            # NOTE: this is assumed to be scaled by the controlnet
            hidden_states += controlnet_residual
 
         return hidden_states
@@ -556,8 +565,6 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         controlnet_block_every_n (`int`, *optional*):
             Interval between transformer blocks that should receive control residuals (for example, `7` to inject after
             every seventh block). Required for Cosmos Transfer2.5.
-        n_controlnet_blocks (`int`, *optional*):
-            The number of control net blocks. If None provided: as many as possible will be placed respecting `controlnet_block_every_n`
         img_context_dim (`int`, *optional*):
             TODO document me
             TODO rename?
@@ -588,7 +595,6 @@ def __init__(
         crossattn_proj_in_channels: int = 1024,
         encoder_hidden_states_channels: int = 1024,
         controlnet_block_every_n: Optional[int] = None,
-        n_control_net_blocks: Optional[int] = None,
         img_context_dim: Optional[int] = None,
     ) -> None:
         super().__init__()
@@ -744,7 +750,7 @@ def forward(
         n_blocks = len(self.transformer_blocks)
         controlnet_block_index_map = {
             block_idx: block_controlnet_hidden_states[idx]
-            for idx, block_idx in list(enumerate(range(0, n_blocks, self.config.controlnet_block_every_n)))[0:self.config.n_controlnet_blocks]
+            for idx, block_idx in list(enumerate(range(0, n_blocks, self.config.controlnet_block_every_n)))
         }
 
         # 5. Transformer blocks
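
A worked example of the mapping above (assuming a 28-block transformer, controlnet_block_every_n = 7 and 4 control residuals, matching the Transfer-General-2B config earlier in the commit): the comprehension produces {0: residual[0], 7: residual[1], 14: residual[2], 21: residual[3]}, so a residual is added to the output of blocks 0, 7, 14 and 21. After this change the slice by n_control_net_blocks is gone, and every stride position indexes into block_controlnet_hidden_states.

    # Illustrative only: which transformer blocks receive a control residual.
    n_blocks, every_n = 28, 7
    print(list(enumerate(range(0, n_blocks, every_n))))  # [(0, 0), (1, 7), (2, 14), (3, 21)]
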
