@@ -10,49 +10,52 @@
 from ..modeling_utils import ModelMixin
 from ..transformers.transformer_cosmos import (
     CosmosPatchEmbed,
+    CosmosTransformerBlock,
 )
 from .controlnet import zero_module


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class CosmosControlNetBlock(nn.Module):
-    def __init__(self, hidden_size: int):
-        super().__init__()
-        self.proj = zero_module(nn.Linear(hidden_size, hidden_size, bias=True))
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        return self.proj(hidden_states)
-
-
 # TODO(migmartin): implement me
 # see i4/projects/cosmos/transfer2/networks/minimal_v4_lvg_dit_control_vace.py
 class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
-    Minimal ControlNet for Cosmos Transfer2.5.
-
-    This module projects encoded control latents into per-block residuals aligned with the
-    `CosmosTransformer3DModel` hidden size. All projections are zero-initialized so the ControlNet
-    starts neutral by default.
+    ControlNet for Cosmos Transfer2.5.
     """

     @register_to_config
     def __init__(
         self,
+        n_controlnet_blocks: int = 4,
         in_channels: int = 16,
+        model_channels: int = 2048,
         num_attention_heads: int = 32,
         attention_head_dim: int = 128,
-        num_layers: int = 4,
+        mlp_ratio: float = 4.0,
+        text_embed_dim: int = 1024,
+        adaln_lora_dim: int = 256,
         patch_size: Tuple[int, int, int] = (1, 2, 2),
-        control_block_indices: Tuple[int, ...] = (6, 13, 20, 27),
     ):
         super().__init__()
-        hidden_size = num_attention_heads * attention_head_dim
-
-        self.patch_embed = CosmosPatchEmbed(in_channels, hidden_size, patch_size, bias=False)
+        self.patch_embed = CosmosPatchEmbed(in_channels, model_channels, patch_size, bias=False)
         self.control_blocks = nn.ModuleList(
-            CosmosControlNetBlock(hidden_size) for _ in range(num_layers)
+            [
+                CosmosTransformerBlock(
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    cross_attention_dim=text_embed_dim,
+                    mlp_ratio=mlp_ratio,
+                    adaln_lora_dim=adaln_lora_dim,
+                    qk_norm="rms_norm",
+                    out_bias=False,
+                    img_context=True,
+                    before_proj=(block_idx == 0),
+                    after_proj=True,
+                )
+                for block_idx in range(n_controlnet_blocks)
+            ]
         )

     def _expand_conditioning_scale(self, conditioning_scale: Union[float, List[float]]) -> List[float]:
@@ -61,7 +64,7 @@ def _expand_conditioning_scale(self, conditioning_scale: Union[float, List[float
         else:
             scales = [conditioning_scale] * len(self.control_blocks)

-        if len(scales) != len(self.control_blocks):
+        if len(scales) < len(self.control_blocks):
             logger.warning(
                 "Received %d control scales, but control network defines %d blocks. "
                 "Scales will be trimmed or repeated to match.",
@@ -75,16 +78,25 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         controlnet_cond: torch.Tensor,
-        timestep: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
         conditioning_scale: Union[float, List[float]] = 1.0,
-        return_dict: bool = True,
     ) -> List[torch.Tensor]:
-        del hidden_states, timestep, encoder_hidden_states  # not used in this minimal control path
-
         control_hidden_states = self.patch_embed(controlnet_cond)
         control_hidden_states = control_hidden_states.flatten(1, 3)

         scales = self._expand_conditioning_scale(conditioning_scale)
-        control_residuals = tuple(block(control_hidden_states) * scale for block, scale in zip(self.control_blocks, scales))
-        return control_residuals
+        x = hidden_states
+
+        # NOTE: args to block
+        # hidden_states: torch.Tensor,
+        # encoder_hidden_states: torch.Tensor,
+        # embedded_timestep: torch.Tensor,
+        # temb: Optional[torch.Tensor] = None,
+        # image_rotary_emb: Optional[torch.Tensor] = None,
+        # extra_pos_emb: Optional[torch.Tensor] = None,
+        # attention_mask: Optional[torch.Tensor] = None,
+        # controlnet_residual: Optional[torch.Tensor] = None,
+        result = []
+        for block, scale in zip(self.control_blocks, scales):
+            x = block(x)
+            result.append(x * scale)
+        return result
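
The loop added at the end of forward() currently calls block(x), while the NOTE above lists a much richer block signature, so the control path is still incomplete (hence the TODO). Below is a hedged sketch of how that loop might be wired once implemented. Everything in it is an assumption drawn from the NOTE, not part of this patch: the conditioning tensors (encoder_hidden_states, embedded_timestep, temb, the rotary/positional embeddings, attention_mask) are presumed to be computed by the caller and threaded through to each control block.

# Hypothetical sketch only -- not part of this patch. Shows how the control loop
# could pass the arguments listed in the NOTE; the extra inputs are assumed to be
# provided by the wrapping transformer.
def _control_loop_sketch(self, x, encoder_hidden_states, embedded_timestep, temb,
                         image_rotary_emb, extra_pos_emb, attention_mask, scales):
    residuals = []
    for block, scale in zip(self.control_blocks, scales):
        # Each control block consumes the running hidden states plus the shared
        # conditioning tensors, mirroring the main transformer blocks.
        x = block(
            hidden_states=x,
            encoder_hidden_states=encoder_hidden_states,
            embedded_timestep=embedded_timestep,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
            extra_pos_emb=extra_pos_emb,
            attention_mask=attention_mask,
        )
        # Weight each per-block output so conditioning_scale can scale the
        # residual contributed by every control block independently.
        residuals.append(x * scale)
    return residuals

Since the blocks are built with after_proj=True, the fork presumably emits a projected residual per block; whether block() returns that residual directly or alongside the running hidden states is left to the referenced minimal_v4_lvg_dit_control_vace.py implementation.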
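On the consumer side, the controlnet_residual argument in the NOTE suggests the usual ControlNet pattern: the main Cosmos transformer adds one returned residual per paired block. A rough, hypothetical sketch of that injection follows; the loop, the pairing of residuals to blocks, and all variable names are illustrative assumptions, not code from this repository.

# Hypothetical consumer-side sketch: inject the ControlNet outputs into the main
# transformer via the `controlnet_residual` argument listed in the NOTE.
control_residuals = controlnet(hidden_states, controlnet_cond, conditioning_scale=1.0)
for i, block in enumerate(transformer.transformer_blocks):
    hidden_states = block(
        hidden_states,
        encoder_hidden_states,
        embedded_timestep,
        temb=temb,
        image_rotary_emb=image_rotary_emb,
        extra_pos_emb=extra_pos_emb,
        attention_mask=attention_mask,
        # Here only the first len(control_residuals) blocks receive a residual;
        # how blocks are actually paired is up to the Transfer2.5 port.
        controlnet_residual=control_residuals[i] if i < len(control_residuals) else None,
    )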