@codeflash-ai codeflash-ai bot commented Jun 1, 2025

📄 9% (0.09x) speedup for HunyuanVideoResnetBlockCausal3D.forward in src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

⏱️ Runtime: 2.60 milliseconds → 2.38 milliseconds (best of 38 runs)
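
The headline figure can be reproduced from the two runtimes. The sketch below assumes the speedup is the time saved divided by the optimized runtime; that formula is an assumption chosen because it matches the reported 9% (0.09x), and `speedup` is a hypothetical helper name, not part of any tool's API.

```python
def speedup(original_ms: float, optimized_ms: float) -> float:
    """Relative speedup: time saved, expressed as a fraction of the optimized runtime."""
    return (original_ms - optimized_ms) / optimized_ms


# Runtimes reported above (best of 38 runs).
ratio = speedup(2.60, 2.38)
print(f"{ratio:.2f}x speedup, i.e. {ratio:.0%}")
```

Dividing by the original runtime instead would give roughly 8%, so the choice of denominator matters when comparing reports.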

📝 Explanation and details

Main changes explained:

  • Removed unnecessary .contiguous(): The call is needed only when a downstream op requires contiguous memory, and most standard PyTorch layers do not. Leaving the input as-is avoids a possible copy into newly allocated memory.
  • Explicit torch.add(): Used torch.add() for the residual sum instead of +, with the stated aim of allowing memory reuse. An in-place version via out= is unsafe for autograd here, so the addition stays out-of-place, and the direct function call only avoids some Python operator overhead.
  • Removed redundant else-blocks to streamline the control flow.
  • Kept normalization and activation chained as in the original, dropping unnecessary intermediate assignments. No real kernel fusion happens, because these are standard PyTorch layers.
  • Left GroupNorm, Dropout, and the convolutions untouched: their speed is dictated by the underlying PyTorch/CUDA implementations, not by the Python code in this method.
  • Kept the signature and logic identical. All function results and edge cases are unchanged.
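
The points about `.contiguous()` and `torch.add()` can be checked with a few tensors. This is a minimal sketch on toy shapes, not the block's actual inputs; it only illustrates when `.contiguous()` copies and that `torch.add(a, b)` is the same out-of-place operation as `a + b`.

```python
import torch

x = torch.randn(2, 3, 4)

# Already contiguous: .contiguous() returns a tensor over the same storage, no copy.
same_storage = x.contiguous().data_ptr() == x.data_ptr()

# A transposed view is not contiguous, so .contiguous() allocates and copies.
xt = x.transpose(1, 2)
copied = xt.contiguous().data_ptr() != xt.data_ptr()

# torch.add(a, b) and a + b compute the same out-of-place result.
a, b = torch.randn(2, 3), torch.randn(2, 3)
same_sum = torch.equal(torch.add(a, b), a + b)

print(same_storage, xt.is_contiguous(), copied, same_sum)
```

The copy only happens for non-contiguous inputs, so the saving from dropping the call depends on how the caller produced the tensor.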

This rewrite preserves correctness while trimming Python overhead; the underlying operators still dominate runtime. Further acceleration would require tuning the lower-level convolution implementation, using mixed precision (autocast), applying torch.compile or tracing, or fusing a custom norm+act+conv kernel.
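
For orientation, the forward pattern discussed here (chained normalization, activation, and convolution, followed by a residual sum through `torch.add()`) can be sketched as follows. This is an illustrative stand-in, not the diffusers implementation: `TinyResidualBlock3D` is a hypothetical name, plain `nn.Conv3d` with symmetric padding replaces the causal convolution, and dropout is omitted.

```python
import torch
import torch.nn as nn


class TinyResidualBlock3D(nn.Module):
    """Hypothetical, simplified 3D residual block (assumption: not the real class)."""

    def __init__(self, in_channels: int, out_channels: int, groups: int = 2):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, in_channels)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(groups, out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.act = nn.SiLU()
        # A 1x1x1 projection is needed only when the channel count changes.
        self.shortcut = (
            nn.Conv3d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        # Norm, activation, and conv chained without intermediate assignments.
        h = self.conv1(self.act(self.norm1(x)))
        h = self.conv2(self.act(self.norm2(h)))
        if self.shortcut is not None:
            residual = self.shortcut(residual)
        # Out-of-place sum, so autograd stays safe.
        return torch.add(h, residual)


block = TinyResidualBlock3D(in_channels=4, out_channels=6).eval()
out = block(torch.randn(1, 4, 2, 4, 4))
print(out.shape)
```

Almost all of the time in such a block is spent inside the convolution and normalization kernels, which is why restructuring the Python around them yields only a single-digit percentage gain.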

Correctness verification report:

| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 12 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 92.3% |

🌀 Generated Regression Tests Details
from typing import Optional

# imports
import pytest  # used for our unit tests
import torch
import torch.nn as nn
import torch.utils.checkpoint
from src.diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import \
    HunyuanVideoResnetBlockCausal3D

# function to test
# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Dummy CausalConv3d for testing purposes
class HunyuanVideoCausalConv3d(nn.Conv3d):
    # For test purposes, this is just a regular Conv3d with the same interface
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        # Accept int or tuple for kernel_size, stride, padding
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=True,
        )

ACT2CLS = {
    "swish": nn.SiLU,
    "silu": nn.SiLU,
    "mish": nn.Mish,
    "gelu": nn.GELU,
    "relu": nn.ReLU,
}

def get_activation(act_fn: str) -> nn.Module:
    """Helper function to get activation function from string.

    Args:
        act_fn (str): Name of activation function.

    Returns:
        nn.Module: Activation function.
    """
    act_fn = act_fn.lower()
    if act_fn in ACT2CLS:
        return ACT2CLS[act_fn]()
    else:
        raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")

# ================== UNIT TESTS ==================

# ----------- Basic Test Cases -----------

def test_forward_basic_identity_channels():
    """
    Test forward with in_channels == out_channels, no dropout.
    Should not use conv_shortcut.
    """
    batch, c, t, h, w = 2, 4, 5, 6, 7
    x = torch.randn(batch, c, t, h, w)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=c, out_channels=c, dropout=0.0, groups=2)
    block.eval()
    out = block(x)

def test_forward_basic_channel_change():
    """
    Test forward with in_channels != out_channels.
    Should use conv_shortcut.
    """
    batch, c_in, c_out, t, h, w = 2, 3, 5, 4, 6, 6
    x = torch.randn(batch, c_in, t, h, w)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=c_in, out_channels=c_out, dropout=0.0, groups=1)
    block.eval()
    out = block(x)


def test_forward_basic_activation_variants():
    """
    Test forward with different non_linearity values.
    """
    batch, c, t, h, w = 1, 6, 2, 2, 2
    for nonlin in ["swish", "silu", "mish", "gelu", "relu"]:
        x = torch.randn(batch, c, t, h, w)
        block = HunyuanVideoResnetBlockCausal3D(in_channels=c, non_linearity=nonlin, groups=2)
        block.eval()
        out = block(x)

def test_forward_basic_groupnorm_groups_divisor():
    """
    Test forward with groups that evenly divide in_channels.
    """
    batch, c, t, h, w = 1, 8, 2, 2, 2
    block = HunyuanVideoResnetBlockCausal3D(in_channels=c, groups=4)
    x = torch.randn(batch, c, t, h, w)
    out = block(x)

# ----------- Edge Test Cases -----------

def test_forward_edge_single_element():
    """
    Test forward with a single element in each dimension.
    """
    x = torch.randn(1, 1, 1, 1, 1)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=1, groups=1)
    out = block(x)

def test_forward_edge_large_group_number():
    """
    Test forward with groups == in_channels (InstanceNorm-like).
    """
    c = 8
    x = torch.randn(1, c, 2, 2, 2)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=c, groups=c)
    out = block(x)

def test_forward_edge_one_group():
    """
    Test forward with groups == 1 (LayerNorm-like).
    """
    c = 5
    x = torch.randn(1, c, 2, 2, 2)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=c, groups=1)
    out = block(x)

def test_forward_edge_invalid_activation():
    """
    Test forward with invalid non_linearity string.
    Should raise ValueError.
    """
    with pytest.raises(ValueError):
        HunyuanVideoResnetBlockCausal3D(in_channels=4, non_linearity="not_an_activation")

def test_forward_edge_invalid_groupnorm():
    """
    Test forward with groups not dividing in_channels.
    Should raise ValueError from GroupNorm.
    """
    # GroupNorm will raise ValueError if num_channels % groups != 0
    with pytest.raises(ValueError):
        HunyuanVideoResnetBlockCausal3D(in_channels=5, groups=3)

def test_forward_edge_zero_input():
    """
    Test forward with all zeros input.
    Output should not be all zeros due to bias in conv layers.
    """
    x = torch.zeros(1, 4, 2, 2, 2)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=4, groups=2)
    out = block(x)

def test_forward_edge_different_shapes():
    """
    Test forward with non-square spatial and temporal dimensions.
    """
    x = torch.randn(2, 4, 3, 5, 7)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=4, groups=2)
    out = block(x)

def test_forward_edge_float16():
    """
    Test forward with float16 input.
    """
    x = torch.randn(1, 4, 2, 2, 2).half()
    block = HunyuanVideoResnetBlockCausal3D(in_channels=4, groups=2)
    block = block.half()
    out = block(x)

def test_forward_edge_float64():
    """
    Test forward with float64 input.
    """
    x = torch.randn(1, 4, 2, 2, 2, dtype=torch.float64)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=4, groups=2)
    block = block.double()
    out = block(x)

def test_forward_edge_gradcheck():
    """
    Test forward/backward for grad correctness using double precision.
    """
    c = 2
    x = torch.randn(1, c, 2, 2, 2, dtype=torch.float64, requires_grad=True)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=c, groups=2)
    block = block.double()
    def func(input):
        return block(input)
    # gradcheck expects input tuple
    torch.autograd.gradcheck(func, (x,))

# ----------- Large Scale Test Cases -----------



def test_forward_large_channels():
    """
    Test forward with large number of channels, but within 100MB.
    """
    # 1*64*2*8*8*4B = 32KB
    batch, c, t, h, w = 1, 64, 2, 8, 8
    x = torch.randn(batch, c, t, h, w)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=c, groups=8)
    out = block(x)

def test_forward_large_all_dims():
    """
    Test forward with all dimensions moderately large, but <100MB.
    """
    # 2*8*4*8*8*4B = 16KB
    batch, c, t, h, w = 2, 8, 4, 8, 8
    x = torch.randn(batch, c, t, h, w)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=c, groups=4)
    out = block(x)

def test_forward_large_out_channels():
    """
    Test forward with in_channels != out_channels, both large.
    """
    batch, c_in, c_out, t, h, w = 1, 32, 64, 2, 8, 8
    x = torch.randn(batch, c_in, t, h, w)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=c_in, out_channels=c_out, groups=8)
    out = block(x)

def test_forward_large_multiple_runs_consistency():
    """
    Test that repeated runs with the same input and block in eval mode produce the same output.
    """
    batch, c, t, h, w = 2, 4, 2, 8, 8
    x = torch.randn(batch, c, t, h, w)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=c, groups=2)
    block.eval()
    out1 = block(x)
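    out2 = block(x)
    # In eval mode dropout is disabled, so repeated runs should match
    assert torch.allclose(out1, out2)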

# ----------- Miscellaneous Robustness -----------

def test_forward_requires_grad_preserved():
    """
    Test that requires_grad is preserved in output if input requires_grad.
    """
    x = torch.randn(1, 4, 2, 2, 2, requires_grad=True)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=4, groups=2)
    out = block(x)

def test_forward_backward_pass():
    """
    Test that backward pass works without error.
    """
    x = torch.randn(1, 4, 2, 2, 2, requires_grad=True)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=4, groups=2)
    out = block(x)
    loss = out.sum()
    loss.backward()
    # Check that every parameter received a gradient
    for param in block.parameters():
        assert param.grad is not None

def test_forward_different_devices():
    """
    Test forward on CPU and (if available) CUDA.
    """
    x = torch.randn(1, 4, 2, 2, 2)
    block = HunyuanVideoResnetBlockCausal3D(in_channels=4, groups=2)
    out_cpu = block(x)
    if torch.cuda.is_available():
        block_cuda = HunyuanVideoResnetBlockCausal3D(in_channels=4, groups=2).cuda()
        x_cuda = x.cuda()
        out_cuda = block_cuda(x_cuda)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

import pytest  # used for our unit tests
import torch
import torch.nn as nn
from src.diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import \
    HunyuanVideoResnetBlockCausal3D

# function to test
# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ACT2CLS = {
    "swish": nn.SiLU,
    "silu": nn.SiLU,
    "mish": nn.Mish,
    "gelu": nn.GELU,
    "relu": nn.ReLU,
}


def get_activation(act_fn: str) -> nn.Module:
    """Helper function to get activation function from string.

    Args:
        act_fn (str): Name of activation function.

    Returns:
        nn.Module: Activation function.
    """

    act_fn = act_fn.lower()
    if act_fn in ACT2CLS:
        return ACT2CLS[act_fn]()
    else:
        raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")


class HunyuanVideoCausalConv3d(nn.Conv3d):
    """
    A 3D convolution layer with causal padding in the temporal dimension.
    For simplicity, this implementation pads only the temporal dimension (dim=2).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        # Accept int or tuple for kernel_size/stride/padding
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding, padding)
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, bias=True)

        self._causal_padding = (kernel_size[0] - 1, 0)  # Only pad left (past) on temporal axis

    def forward(self, x):
        # Pad only the temporal dimension (dim=2)
        pad = [0, 0, 0, 0, self._causal_padding[0], self._causal_padding[1]]  # (D, H, W) => (t, h, w)
        x = torch.nn.functional.pad(x, pad)
        return super().forward(x)

# ------------------- UNIT TESTS FOR forward -------------------

# Helper function to create input tensors
def make_input(batch, channels, frames, height, width, fill_value=None):
    shape = (batch, channels, frames, height, width)
    if fill_value is not None:
        return torch.full(shape, fill_value)
    else:
        return torch.randn(shape)

# ------------------- BASIC TEST CASES -------------------





def test_forward_basic_groupnorm():
    # Test with different group numbers
    for groups in [1, 2]:
        model = HunyuanVideoResnetBlockCausal3D(in_channels=4, out_channels=4, groups=groups)
        x = make_input(2, 4, 4, 4, 4)
        codeflash_output = model.forward(x); y = codeflash_output

# ------------------- EDGE TEST CASES -------------------




def test_forward_edge_large_groupnorm():
    # GroupNorm groups equal channels (InstanceNorm-like)
    model = HunyuanVideoResnetBlockCausal3D(in_channels=8, out_channels=8, groups=8)
    x = make_input(1, 8, 3, 3, 3)
    codeflash_output = model.forward(x); y = codeflash_output

def test_forward_edge_invalid_activation():
    # Should raise ValueError for unknown activation
    with pytest.raises(ValueError):
        HunyuanVideoResnetBlockCausal3D(in_channels=2, out_channels=2, non_linearity="notarealactivation")

def test_forward_edge_invalid_groups():
    # Should raise ValueError if groups does not divide channels
    with pytest.raises(ValueError):
        HunyuanVideoResnetBlockCausal3D(in_channels=5, out_channels=5, groups=3)







def test_forward_large_channels():
    # Large number of channels, but within 100MB
    batch = 1
    channels = 32
    frames = 3
    height = 8
    width = 8
    model = HunyuanVideoResnetBlockCausal3D(in_channels=channels, out_channels=channels)
    x = make_input(batch, channels, frames, height, width)
    codeflash_output = model.forward(x); y = codeflash_output



def test_forward_large_groups():
    # Large number of groups in GroupNorm
    batch = 2
    channels = 16
    frames = 3
    height = 8
    width = 8
    model = HunyuanVideoResnetBlockCausal3D(in_channels=channels, out_channels=channels, groups=8)
    x = make_input(batch, channels, frames, height, width)
    codeflash_output = model.forward(x); y = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, `git checkout codeflash/optimize-HunyuanVideoResnetBlockCausal3D.forward-mbdzwar4` and push.

@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Jun 1, 2025
@codeflash-ai codeflash-ai bot requested a review from aseembits93 June 1, 2025 18:29