Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions src/diffusers/models/controlnets/controlnet_xs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1575,23 +1575,32 @@ def from_modules(
):
base_to_ctrl = ctrl_midblock.base_to_ctrl
ctrl_to_base = ctrl_midblock.ctrl_to_base
ctrl_midblock = ctrl_midblock.midblock
ctrl_midblock_ref = ctrl_midblock.midblock

# get params
def get_first_cross_attention(midblock):
return midblock.attentions[0].transformer_blocks[0].attn2
# -- Helper: Batch extract attention/other attributes with local vars to avoid attribute chain overhead

# Cache frequently used objects
base_att = base_midblock.attentions[0]
ctrl_att = ctrl_midblock_ref.attentions[0]
base_resnet = base_midblock.resnets[0]
ctrl_resnet = ctrl_midblock_ref.resnets[0]

# Store attention blocks for reuse
base_attn0 = base_att.transformer_blocks[0].attn2
ctrl_attn0 = ctrl_att.transformer_blocks[0].attn2

# Prefetch attributes, minimize repeated getattr/dot-chain calls
base_channels = ctrl_to_base.out_channels
ctrl_channels = ctrl_to_base.in_channels
transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks)
temb_channels = base_midblock.resnets[0].time_emb_proj.in_features
num_groups = base_midblock.resnets[0].norm1.num_groups
ctrl_num_groups = ctrl_midblock.resnets[0].norm1.num_groups
base_num_attention_heads = get_first_cross_attention(base_midblock).heads
ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
use_linear_projection = base_midblock.attentions[0].use_linear_projection
transformer_layers_per_block = len(base_att.transformer_blocks)
temb_channels = base_resnet.time_emb_proj.in_features
num_groups = base_resnet.norm1.num_groups
ctrl_num_groups = ctrl_resnet.norm1.num_groups
base_num_attention_heads = base_attn0.heads
ctrl_num_attention_heads = ctrl_attn0.heads
cross_attention_dim = base_attn0.cross_attention_dim
upcast_attention = base_attn0.upcast_attention
use_linear_projection = base_att.use_linear_projection

# create model
model = cls(
Expand All @@ -1611,7 +1620,7 @@ def get_first_cross_attention(midblock):
# load weights
model.base_to_ctrl.load_state_dict(base_to_ctrl.state_dict())
model.base_midblock.load_state_dict(base_midblock.state_dict())
model.ctrl_midblock.load_state_dict(ctrl_midblock.state_dict())
model.ctrl_midblock.load_state_dict(ctrl_midblock_ref.state_dict())
model.ctrl_to_base.load_state_dict(ctrl_to_base.state_dict())

return model
Expand Down