Conversation

codeflash-ai bot commented on Jun 1, 2025

📄 16% (0.16x) speedup for ControlNetXSCrossAttnMidBlock2D.from_modules in src/diffusers/models/controlnets/controlnet_xs.py

⏱️ Runtime: 4.17 microseconds → 3.61 microseconds (best of 6 runs)

📝 Explanation and details

Key Optimizations

  • Reduced attribute traversal: The hottest lines are repeated attribute-chain traversals into deep modules. These are now batched (fetched and cached once per call) using local variables such as base_att = base_midblock.attentions[0].
  • Eliminated repeated get_first_cross_attention calls: All attributes of the same attention block are grabbed once and stored in variables for reuse.
  • The profile showed that explicit function calls (like get_first_cross_attention) were slow due to repeated traversals; these are now done only once.
  • No changes to function signatures, return values, or behavior. Comments are left unchanged where the logic is identical, but some in the classmethod are trimmed, since the slow point was pure property access.
  • The type hint for "MidBlockControlNetXSAdapter" is kept as a string for compatibility and to avoid codebase inference; this is unchanged.

This reduces the number of (expensive) sequential attribute resolutions, which line profiling showed dominate runtime, especially for large objects in model graphs.
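
To make the change concrete, here is a minimal sketch of the before/after access pattern. The attribute chain (attentions[0].transformer_blocks[0].attn2) mirrors the one named above, but the function names and surrounding code are illustrative, not the actual diffusers implementation.

# Before: every parameter re-walks the same attribute chain.
def params_before(base_midblock):
    def get_first_cross_attention(midblock):
        return midblock.attentions[0].transformer_blocks[0].attn2

    return dict(
        num_attention_heads=get_first_cross_attention(base_midblock).heads,
        cross_attention_dim=get_first_cross_attention(base_midblock).cross_attention_dim,
        upcast_attention=get_first_cross_attention(base_midblock).upcast_attention,
        use_linear_projection=base_midblock.attentions[0].use_linear_projection,
    )

# After: traverse once, cache the deep objects in locals, then read cheap local names.
def params_after(base_midblock):
    base_att = base_midblock.attentions[0]                   # resolved once
    base_cross_attn = base_att.transformer_blocks[0].attn2   # resolved once

    return dict(
        num_attention_heads=base_cross_attn.heads,
        cross_attention_dim=base_cross_attn.cross_attention_dim,
        upcast_attention=base_cross_attn.upcast_attention,
        use_linear_projection=base_att.use_linear_projection,
    )

Both functions return the same values; the second simply resolves each deep attribute a single time.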

Correctness verification report:

Test                            Status
⚙️ Existing Unit Tests          🔘 None Found
🌀 Generated Regression Tests   6 Passed
⏪ Replay Tests                 🔘 None Found
🔎 Concolic Coverage Tests      🔘 None Found
📊 Tests Coverage               85.7%
🌀 Generated Regression Tests Details
from math import gcd
from typing import Optional

# imports
import pytest  # used for our unit tests
import torch
from src.diffusers.models.controlnets.controlnet_xs import \
    ControlNetXSCrossAttnMidBlock2D
from torch import nn

# function to test
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Minimal stub for zero_module
def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

# Minimal stub for make_zero_conv
def make_zero_conv(in_channels, out_channels=None):
    if out_channels is None:
        out_channels = in_channels
    return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))
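
# Example: make_zero_conv(4, 8) returns an nn.Conv2d(4, 8, kernel_size=1) whose weight and bias are all zeros.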

# Minimal stub for find_largest_factor
def find_largest_factor(number, max_factor):
    factor = max_factor
    if factor >= number:
        return number
    while factor != 0:
        residual = number % factor
        if residual == 0:
            return factor
        factor -= 1
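
# Example: find_largest_factor(20, 8) == 5, since 5 is the largest divisor of 20 that is <= 8.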

# Minimal stub for cross-attention
class DummyCrossAttention(nn.Module):
    def __init__(self, heads, cross_attention_dim, upcast_attention=False):
        super().__init__()
        self.heads = heads
        self.cross_attention_dim = cross_attention_dim
        self.upcast_attention = upcast_attention

# Minimal stub for transformer block
class DummyTransformerBlock(nn.Module):
    def __init__(self, heads, cross_attention_dim, upcast_attention=False):
        super().__init__()
        self.attn2 = DummyCrossAttention(heads, cross_attention_dim, upcast_attention)

# Minimal stub for attention block
class DummyAttentionBlock(nn.Module):
    def __init__(self, heads, cross_attention_dim, num_layers, use_linear_projection, upcast_attention=False):
        super().__init__()
        self.transformer_blocks = nn.ModuleList(
            [DummyTransformerBlock(heads, cross_attention_dim, upcast_attention) for _ in range(num_layers)]
        )
        self.use_linear_projection = use_linear_projection

# Minimal stub for resnet block
class DummyResnetBlock(nn.Module):
    def __init__(self, in_features, num_groups):
        super().__init__()
        self.time_emb_proj = nn.Linear(in_features, in_features)
        self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=in_features)

# Minimal stub for UNetMidBlock2DCrossAttn
class UNetMidBlock2DCrossAttn(nn.Module):
    def __init__(
        self,
        transformer_layers_per_block,
        in_channels,
        temb_channels,
        resnet_groups,
        cross_attention_dim,
        num_attention_heads,
        use_linear_projection,
        upcast_attention=False,
        out_channels=None,
    ):
        super().__init__()
        self.attentions = nn.ModuleList([
            DummyAttentionBlock(
                num_attention_heads,
                cross_attention_dim,
                transformer_layers_per_block,
                use_linear_projection,
                upcast_attention=upcast_attention,
            )
        ])
        self.resnets = nn.ModuleList([
            DummyResnetBlock(temb_channels, resnet_groups)
        ])
        self.in_channels = in_channels
        self.out_channels = out_channels if out_channels is not None else in_channels

    def state_dict(self, *args, **kwargs):
        # Return dummy state dict for testing
        return super().state_dict(*args, **kwargs)

    def load_state_dict(self, state_dict, strict=True):
        # Accept any state dict for testing
        return super().load_state_dict(state_dict, strict=strict)

# Dummy adapter class for ctrl_midblock
class MidBlockControlNetXSAdapter:
    def __init__(self, base_to_ctrl, ctrl_to_base, midblock):
        self.base_to_ctrl = base_to_ctrl
        self.ctrl_to_base = ctrl_to_base
        self.midblock = midblock
from src.diffusers.models.controlnets.controlnet_xs import \
    ControlNetXSCrossAttnMidBlock2D

# ------------------- UNIT TESTS -------------------

# Helper to create dummy modules for test
def create_dummy_modules(
    base_channels=4,
    ctrl_channels=8,
    temb_channels=16,
    norm_num_groups=2,
    ctrl_max_norm_num_groups=2,
    transformer_layers_per_block=1,
    base_num_attention_heads=1,
    ctrl_num_attention_heads=1,
    cross_attention_dim=32,
    upcast_attention=False,
    use_linear_projection=True,
):
    # Create base_midblock
    base_midblock = UNetMidBlock2DCrossAttn(
        transformer_layers_per_block=transformer_layers_per_block,
        in_channels=base_channels,
        temb_channels=temb_channels,
        resnet_groups=norm_num_groups,
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=base_num_attention_heads,
        use_linear_projection=use_linear_projection,
        upcast_attention=upcast_attention,
    )
    # Create ctrl_midblock
    ctrl_midblock_inner = UNetMidBlock2DCrossAttn(
        transformer_layers_per_block=transformer_layers_per_block,
        in_channels=ctrl_channels + base_channels,
        out_channels=ctrl_channels,
        temb_channels=temb_channels,
        resnet_groups=find_largest_factor(
            gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups
        ),
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=ctrl_num_attention_heads,
        use_linear_projection=use_linear_projection,
        upcast_attention=upcast_attention,
    )
    base_to_ctrl = make_zero_conv(base_channels, base_channels)
    ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)
    ctrl_midblock = MidBlockControlNetXSAdapter(base_to_ctrl, ctrl_to_base, ctrl_midblock_inner)
    return base_midblock, ctrl_midblock
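
# Illustrative usage of the helper above (a sketch, not one of the generated tests):
#
#     base_midblock, ctrl_midblock = create_dummy_modules()
#     block = ControlNetXSCrossAttnMidBlock2D.from_modules(base_midblock, ctrl_midblock)
#
# from_modules is expected to read channel widths, attention heads and norm groups from the
# passed modules and copy their weights into a freshly built ControlNetXSCrossAttnMidBlock2D;
# whether these minimal stubs expose every attribute the real implementation touches is an
# assumption, which is why this is left as a comment rather than an executable test.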

# ------------------- BASIC TESTS -------------------

def test_from_modules_invalid_input_raises():
    # Test that passing wrong types raises AttributeError
    with pytest.raises(AttributeError):
        ControlNetXSCrossAttnMidBlock2D.from_modules("not_a_module", "not_a_module")

# ------------------- LARGE SCALE TESTS -------------------

from math import gcd
from typing import Optional

# imports
import pytest  # used for our unit tests
import torch
from src.diffusers.models.controlnets.controlnet_xs import \
    ControlNetXSCrossAttnMidBlock2D
from torch import nn

# function to test
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# --- Minimal stub implementations for dependencies ---

def zero_module(module):
    # Set all parameters to zero for deterministic behavior
    for p in module.parameters():
        nn.init.constant_(p, 0)
    return module

def make_zero_conv(in_channels, out_channels=None):
    if out_channels is None:
        out_channels = in_channels
    return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))

def find_largest_factor(number, max_factor):
    factor = max_factor
    if factor >= number:
        return number
    while factor != 0:
        residual = number % factor
        if residual == 0:
            return factor
        factor -= 1

# Stub for the attention block used in the midblock
class DummyAttn:
    def __init__(self, heads, cross_attention_dim, upcast_attention):
        self.heads = heads
        self.cross_attention_dim = cross_attention_dim
        self.upcast_attention = upcast_attention

# Stub for transformer block
class DummyTransformerBlock:
    def __init__(self, attn2):
        self.attn2 = attn2

# Stub for attention wrapper
class DummyAttention:
    def __init__(self, transformer_blocks, use_linear_projection):
        self.transformer_blocks = transformer_blocks
        self.use_linear_projection = use_linear_projection

# Stub for resnet block
class DummyResnet:
    def __init__(self, in_features, num_groups):
        # time_emb_proj is a dummy linear layer
        self.time_emb_proj = nn.Linear(in_features, in_features)
        self.norm1 = nn.GroupNorm(num_groups, in_features)

# Minimal stub for UNetMidBlock2DCrossAttn
class UNetMidBlock2DCrossAttn(nn.Module):
    def __init__(
        self,
        transformer_layers_per_block,
        in_channels,
        temb_channels,
        resnet_groups,
        cross_attention_dim,
        num_attention_heads,
        use_linear_projection,
        upcast_attention,
        out_channels=None
    ):
        super().__init__()
        self.in_channels = in_channels
        self.temb_channels = temb_channels
        self.resnet_groups = resnet_groups
        self.cross_attention_dim = cross_attention_dim
        self.num_attention_heads = num_attention_heads
        self.use_linear_projection = use_linear_projection
        self.upcast_attention = upcast_attention
        self.out_channels = out_channels if out_channels is not None else in_channels

        # Simulate resnets and attentions as lists with dummy objects
        self.resnets = [DummyResnet(temb_channels, resnet_groups)]
        attn2 = DummyAttn(num_attention_heads, cross_attention_dim, upcast_attention)
        transformer_blocks = [DummyTransformerBlock(attn2) for _ in range(transformer_layers_per_block)]
        self.attentions = [DummyAttention(transformer_blocks, use_linear_projection)]

    def state_dict(self):
        # Return a dummy state dict
        return {"dummy_param": torch.tensor(1.)}

    def load_state_dict(self, state_dict):
        # Accept any state dict for testing
        pass

# Adapter class for ctrl_midblock in from_modules
class MidBlockControlNetXSAdapter:
    def __init__(self, base_to_ctrl, ctrl_to_base, midblock):
        self.base_to_ctrl = base_to_ctrl
        self.ctrl_to_base = ctrl_to_base
        self.midblock = midblock
from src.diffusers.models.controlnets.controlnet_xs import \
    ControlNetXSCrossAttnMidBlock2D

# --- Unit Tests ---

# Helper to make a dummy midblock adapter
def make_dummy_adapter(
    base_channels=8, ctrl_channels=12, temb_channels=16, 
    transformer_layers_per_block=1, norm_num_groups=4, ctrl_max_norm_num_groups=4,
    base_num_attention_heads=2, ctrl_num_attention_heads=3,
    cross_attention_dim=32, upcast_attention=False, use_linear_projection=True
):
    base_to_ctrl = make_zero_conv(base_channels, base_channels)
    ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)
    midblock = UNetMidBlock2DCrossAttn(
        transformer_layers_per_block=transformer_layers_per_block,
        in_channels=ctrl_channels + base_channels,
        out_channels=ctrl_channels,
        temb_channels=temb_channels,
        resnet_groups=find_largest_factor(
            gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups
        ),
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=ctrl_num_attention_heads,
        use_linear_projection=use_linear_projection,
        upcast_attention=upcast_attention,
    )
    return MidBlockControlNetXSAdapter(base_to_ctrl, ctrl_to_base, midblock)

# Helper to make a dummy base midblock
def make_dummy_base_midblock(
    base_channels=8, temb_channels=16, transformer_layers_per_block=1, 
    norm_num_groups=4, base_num_attention_heads=2,
    cross_attention_dim=32, upcast_attention=False, use_linear_projection=True
):
    return UNetMidBlock2DCrossAttn(
        transformer_layers_per_block=transformer_layers_per_block,
        in_channels=base_channels,
        temb_channels=temb_channels,
        resnet_groups=norm_num_groups,
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=base_num_attention_heads,
        use_linear_projection=use_linear_projection,
        upcast_attention=upcast_attention,
    )

# 1. Basic Test Cases

def test_from_modules_edge_invalid_adapter():
    """
    Edge: Test that missing attributes in adapter raises AttributeError.
    """
    base_midblock = make_dummy_base_midblock()
    # Adapter missing ctrl_to_base
    class BadAdapter:
        def __init__(self):
            self.base_to_ctrl = make_zero_conv(8, 8)
            self.midblock = make_dummy_base_midblock()
    bad_adapter = BadAdapter()
    with pytest.raises(AttributeError):
        ControlNetXSCrossAttnMidBlock2D.from_modules(base_midblock, bad_adapter)

To edit these changes, git checkout codeflash/optimize-ControlNetXSCrossAttnMidBlock2D.from_modules-mbdtqg3s and push.

codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) label on Jun 1, 2025
codeflash-ai bot requested a review from aseembits93 on Jun 1, 2025 at 15:37