diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py index 608be6b70277..f42a4a6d667b 100644 --- a/src/diffusers/models/controlnets/controlnet_xs.py +++ b/src/diffusers/models/controlnets/controlnet_xs.py @@ -1575,23 +1575,32 @@ def from_modules( ): base_to_ctrl = ctrl_midblock.base_to_ctrl ctrl_to_base = ctrl_midblock.ctrl_to_base - ctrl_midblock = ctrl_midblock.midblock + ctrl_midblock_ref = ctrl_midblock.midblock - # get params - def get_first_cross_attention(midblock): - return midblock.attentions[0].transformer_blocks[0].attn2 + # -- Helper: Batch extract attention/other attributes with local vars to avoid attribute chain overhead + + # Cache frequently used objects + base_att = base_midblock.attentions[0] + ctrl_att = ctrl_midblock_ref.attentions[0] + base_resnet = base_midblock.resnets[0] + ctrl_resnet = ctrl_midblock_ref.resnets[0] + + # Store attention blocks for reuse + base_attn0 = base_att.transformer_blocks[0].attn2 + ctrl_attn0 = ctrl_att.transformer_blocks[0].attn2 + # Prefetch attributes, minimize repeated getattr/dot-chain calls base_channels = ctrl_to_base.out_channels ctrl_channels = ctrl_to_base.in_channels - transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks) - temb_channels = base_midblock.resnets[0].time_emb_proj.in_features - num_groups = base_midblock.resnets[0].norm1.num_groups - ctrl_num_groups = ctrl_midblock.resnets[0].norm1.num_groups - base_num_attention_heads = get_first_cross_attention(base_midblock).heads - ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads - cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim - upcast_attention = get_first_cross_attention(base_midblock).upcast_attention - use_linear_projection = base_midblock.attentions[0].use_linear_projection + transformer_layers_per_block = len(base_att.transformer_blocks) + temb_channels = base_resnet.time_emb_proj.in_features + num_groups = base_resnet.norm1.num_groups + ctrl_num_groups = ctrl_resnet.norm1.num_groups + base_num_attention_heads = base_attn0.heads + ctrl_num_attention_heads = ctrl_attn0.heads + cross_attention_dim = base_attn0.cross_attention_dim + upcast_attention = base_attn0.upcast_attention + use_linear_projection = base_att.use_linear_projection # create model model = cls( @@ -1611,7 +1620,7 @@ def get_first_cross_attention(midblock): # load weights model.base_to_ctrl.load_state_dict(base_to_ctrl.state_dict()) model.base_midblock.load_state_dict(base_midblock.state_dict()) - model.ctrl_midblock.load_state_dict(ctrl_midblock.state_dict()) + model.ctrl_midblock.load_state_dict(ctrl_midblock_ref.state_dict()) model.ctrl_to_base.load_state_dict(ctrl_to_base.state_dict()) return model