codeflash-ai bot commented Jun 1, 2025

📄 51% (0.51x) speedup for FreeInitMixin._apply_freq_filter in src/diffusers/pipelines/free_init_utils.py

⏱️ Runtime : 32.0 milliseconds → 21.2 milliseconds (best of 113 runs)

📝 Explanation and details

Here’s an optimized version that reduces the memory footprint, performs operations in-place where safe, avoids repeated calculations, and fuses chained calls. The function’s logic and return value are unchanged.

Key Optimizations (a sketch of the resulting function follows this list):

- Fused the FFT and shift operations for both `x` and `noise`.
- Used `torch.sub` instead of `1 - low_pass_filter` for better performance and clarity.
- Performed the multiplications and addition in-place (`mul_`, `add_`) to reduce memory overhead and temporary tensors.
- Used the same tensor (`x_freq`) as the accumulator for the frequency-domain mixture.
- Reused the `dims` variable to avoid repeated tuple construction.
- No extra allocations beyond what is mathematically necessary; all outputs and logic are unchanged.
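A minimal sketch of what the optimized method could look like, assuming (as in `free_init_utils.py`) that the transform runs over the last three dimensions; the names and exact `dim` tuple here are illustrative rather than the PR's verbatim diff:

```python
import torch
import torch.fft as fft

class FreeInitMixin:
    def _apply_freq_filter(self, x, noise, low_pass_filter):
        dims = (-3, -2, -1)  # reused tuple instead of rebuilding it per call

        # Fused FFT + shift for both inputs
        x_freq = fft.fftshift(fft.fftn(x, dim=dims), dim=dims)
        noise_freq = fft.fftshift(fft.fftn(noise, dim=dims), dim=dims)

        # High-pass complement via torch.sub
        high_pass_filter = torch.sub(1, low_pass_filter)

        # In-place mixture; x_freq doubles as the accumulator
        x_freq.mul_(low_pass_filter)
        x_freq.add_(noise_freq.mul_(high_pass_filter))

        # Inverse shift + IFFT, keeping only the real part
        return fft.ifftn(fft.ifftshift(x_freq, dim=dims), dim=dims).real
```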

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 108 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests Details
import math

# imports
import pytest  # used for our unit tests
import torch
import torch.fft as fft
from src.diffusers.pipelines.free_init_utils import FreeInitMixin

# unit tests

class Dummy(FreeInitMixin):
    pass
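# Note: an empty subclass is enough here, since FreeInitMixin only contributes
# helper methods; _apply_freq_filter can be called without extra state or setup.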

# Helper for approximate equality (since FFT/IFFT may introduce small numerical errors)
def tensors_allclose(a, b, rtol=1e-5, atol=1e-6):
    if a.shape != b.shape:
        return False
    return torch.all(torch.isclose(a, b, rtol=rtol, atol=atol)).item()
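# For reference, a standalone sanity check (an illustrative addition, assuming
# nothing beyond torch) of the two identities the basic tests below rely on:
# an all-ones filter acts as the identity, and frequency-domain mixing is linear.
def test_fft_identities_standalone():
    dims = (-2, -1)
    x = torch.randn(4, 4)
    n = torch.randn(4, 4)
    # IFFT(FFT(x)) recovers x, so an all-ones low-pass filter returns x
    roundtrip = fft.ifftn(fft.fftn(x, dim=dims), dim=dims).real
    assert torch.allclose(roundtrip, x, atol=1e-5)
    # Linearity: a constant 0.5 filter in frequency equals averaging in space
    mixed = fft.ifftn(0.5 * fft.fftn(x, dim=dims) + 0.5 * fft.fftn(n, dim=dims), dim=dims).real
    assert torch.allclose(mixed, (x + n) / 2, atol=1e-5)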

# 1. Basic Test Cases

def test_identity_low_pass_filter():
    """If low_pass_filter is all ones, output should match x exactly (noise ignored)."""
    x = torch.randn(2, 3, 4, 4)
    noise = torch.randn(2, 3, 4, 4)
    low_pass_filter = torch.ones(2, 3, 4, 4)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output

def test_identity_high_pass_filter():
    """If low_pass_filter is all zeros, output should match noise exactly (x ignored)."""
    x = torch.randn(2, 3, 4, 4)
    noise = torch.randn(2, 3, 4, 4)
    low_pass_filter = torch.zeros(2, 3, 4, 4)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output

def test_half_and_half_filter():
    """If low_pass_filter is 0.5 everywhere, output should be average of x and noise."""
    x = torch.randn(2, 3, 4, 4)
    noise = torch.randn(2, 3, 4, 4)
    low_pass_filter = torch.full((2, 3, 4, 4), 0.5)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output
    # Since FFT/IFFT is linear, this should be the average
    expected = (x + noise) / 2
    assert tensors_allclose(out, expected)

def test_broadcastable_filter():
    """Test that low_pass_filter can be broadcasted (e.g., shape is smaller than x/noise)."""
    x = torch.randn(1, 2, 4, 4)
    noise = torch.randn(1, 2, 4, 4)
    low_pass_filter = torch.tensor([[[[1.0]], [[0.0]]]])  # shape (1,2,1,1)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output

def test_real_and_complex_input():
    """Test that function works for real input and returns real output."""
    x = torch.randn(2, 3, 4, 4)
    noise = torch.randn(2, 3, 4, 4)
    low_pass_filter = torch.ones(2, 3, 4, 4)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output

# 2. Edge Test Cases

def test_zero_input():
    """Test with all-zero x and all-zero noise (output should be all zeros)."""
    x = torch.zeros(1, 1, 4, 4)
    noise = torch.zeros(1, 1, 4, 4)
    low_pass_filter = torch.ones(1, 1, 4, 4)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output

def test_extreme_values():
    """Test with large and small values in x and noise."""
    x = torch.full((1,1,4,4), 1e10)
    noise = torch.full((1,1,4,4), -1e10)
    low_pass_filter = torch.full((1,1,4,4), 0.5)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output
    expected = torch.zeros_like(x)
    # noise == -x exactly, so the frequency-domain terms cancel; keep a generous
    # tolerance since the inputs sit at 1e10 scale in float32
    assert tensors_allclose(out, expected, atol=1.0)

def test_non_square_input():
    """Test with non-square spatial dimensions."""
    x = torch.randn(1, 2, 3, 5)
    noise = torch.randn(1, 2, 3, 5)
    low_pass_filter = torch.ones(1, 2, 3, 5)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output

def test_singleton_dimensions():
    """Test with singleton dimensions (e.g., batch size 1, channel 1)."""
    x = torch.randn(1, 1, 8, 8)
    noise = torch.randn(1, 1, 8, 8)
    low_pass_filter = torch.ones(1, 1, 8, 8)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output

def test_incorrect_shapes():
    """Test that mismatched shapes raise an error."""
    x = torch.randn(1, 2, 4, 4)
    noise = torch.randn(1, 2, 4, 4)
    low_pass_filter = torch.ones(1, 1, 4, 3)  # Incorrect shape
    with pytest.raises(RuntimeError):
        Dummy()._apply_freq_filter(x, noise, low_pass_filter)

def test_non_float_input():
    """Test that integer input is promoted to float."""
    x = torch.randint(0, 10, (1, 1, 4, 4), dtype=torch.int32)
    noise = torch.randint(0, 10, (1, 1, 4, 4), dtype=torch.int32)
    low_pass_filter = torch.ones(1, 1, 4, 4)
    # Should not raise, output should be float
    codeflash_output = Dummy()._apply_freq_filter(x.float(), noise.float(), low_pass_filter); out = codeflash_output

def test_filter_out_of_bounds():
    """Test that filter values outside [0,1] are handled (should not crash, but result is linear mix)."""
    x = torch.ones(1, 1, 4, 4)
    noise = torch.zeros(1, 1, 4, 4)
    low_pass_filter = torch.full((1, 1, 4, 4), 1.5)  # >1
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output
    # Since filter > 1, the result is x * 1.5 + noise * -0.5
    expected = x * 1.5 + noise * -0.5
    assert tensors_allclose(out, expected)

def test_filter_negative():
    """Test with negative filter values."""
    x = torch.ones(1, 1, 4, 4)
    noise = torch.zeros(1, 1, 4, 4)
    low_pass_filter = torch.full((1, 1, 4, 4), -0.5)  # <0
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output
    # Should be x * -0.5 + noise * 1.5
    expected = x * -0.5 + noise * 1.5
    assert tensors_allclose(out, expected)

def test_empty_tensor():
    """Test with empty tensors (should not crash, output should be empty)."""
    x = torch.empty(0, 2, 4, 4)
    noise = torch.empty(0, 2, 4, 4)
    low_pass_filter = torch.ones(0, 2, 4, 4)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output

# 3. Large Scale Test Cases

def test_large_tensor_performance():
    """Test with a large tensor (but <100MB)."""
    # 8 * 3 * 128 * 128 elements * 4 bytes ≈ 1.5 MB per tensor
    x = torch.randn(8, 3, 128, 128)
    noise = torch.randn(8, 3, 128, 128)
    low_pass_filter = torch.ones(8, 3, 128, 128)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output

def test_large_broadcast_filter():
    """Test with a large tensor and broadcastable filter."""
    x = torch.randn(8, 3, 64, 64)
    noise = torch.randn(8, 3, 64, 64)
    low_pass_filter = torch.ones(1, 3, 1, 1)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output

def test_large_tensor_half_filter():
    """Test with large tensor and half filter."""
    x = torch.randn(8, 3, 128, 128)
    noise = torch.randn(8, 3, 128, 128)
    low_pass_filter = torch.full((8, 3, 128, 128), 0.5)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output
    expected = (x + noise) / 2
    assert tensors_allclose(out, expected, atol=1e-4)

def test_large_tensor_random_filter():
    """Test with large tensor and random filter values in [0,1]."""
    x = torch.randn(4, 3, 128, 128)
    noise = torch.randn(4, 3, 128, 128)
    low_pass_filter = torch.rand(4, 3, 128, 128)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output
    # Mixing happens per frequency, not per pixel, so a pointwise min/max bound
    # on the output does not hold in general; check shape and finiteness instead
    assert out.shape == x.shape
    assert torch.isfinite(out).all()

def test_large_tensor_multi_channel():
    """Test with large tensor and many channels."""
    x = torch.randn(2, 16, 64, 64)
    noise = torch.randn(2, 16, 64, 64)
    low_pass_filter = torch.ones(2, 16, 64, 64)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, low_pass_filter); out = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

import math

# imports
import pytest  # used for our unit tests
import torch
import torch.fft as fft
from src.diffusers.pipelines.free_init_utils import FreeInitMixin

# unit tests

class Dummy(FreeInitMixin):
    pass

# Helper function to compare tensors (since we can't use torch.testing)
def tensors_close(a, b, atol=1e-5, rtol=1e-4):
    if a.shape != b.shape:
        return False
    diff = (a - b).abs()
    tol = atol + rtol * b.abs()
    return bool(torch.all(diff <= tol))

# ---------------------- BASIC TEST CASES ----------------------

def test_identity_low_pass_filter():
    # All ones low-pass filter: output should be x
    x = torch.rand(1, 2, 4, 4)
    noise = torch.rand(1, 2, 4, 4)
    lpf = torch.ones(1, 2, 4, 4)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output

def test_zero_low_pass_filter():
    # All zeros low-pass filter: output should be noise
    x = torch.rand(1, 2, 4, 4)
    noise = torch.rand(1, 2, 4, 4)
    lpf = torch.zeros(1, 2, 4, 4)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output

def test_half_half_low_pass_filter():
    # Half low frequencies from x, half from noise
    x = torch.ones(1, 1, 4, 4)
    noise = torch.zeros(1, 1, 4, 4)
    lpf = torch.zeros(1, 1, 4, 4)
    lpf[..., :2, :] = 1  # top half: from x, bottom half: from noise
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output

def test_input_and_noise_are_equal():
    # If x == noise, output should be x (regardless of filter)
    x = torch.rand(2, 3, 8, 8)
    lpf = torch.rand(2, 3, 8, 8)
    codeflash_output = Dummy()._apply_freq_filter(x, x, lpf); result = codeflash_output
    assert tensors_close(result, x)

def test_low_pass_filter_between_0_and_1():
    # If low_pass_filter is 0.5 everywhere, output should be a mix of x and noise
    x = torch.ones(1, 1, 4, 4)
    noise = torch.zeros(1, 1, 4, 4)
    lpf = torch.full((1, 1, 4, 4), 0.5)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output

# ---------------------- EDGE TEST CASES ----------------------

def test_empty_tensor():
    # Empty input tensors
    x = torch.empty(0, 2, 4, 4)
    noise = torch.empty(0, 2, 4, 4)
    lpf = torch.ones(0, 2, 4, 4)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output

def test_single_element_tensor():
    # Single element
    x = torch.tensor([[[[1.0]]]])
    noise = torch.tensor([[[[2.0]]]])
    lpf = torch.tensor([[[[1.0]]]])
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output
    lpf = torch.tensor([[[[0.0]]]])
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output

def test_mismatched_shapes():
    # Mismatched shapes should raise an error
    x = torch.rand(1, 2, 4, 4)
    noise = torch.rand(1, 2, 4, 3)
    lpf = torch.ones(1, 2, 4, 4)
    with pytest.raises(RuntimeError):
        Dummy()._apply_freq_filter(x, noise, lpf)

def test_nan_and_inf_in_input():
    # Should propagate NaN and Inf
    x = torch.tensor([[[[math.nan, math.inf], [1.0, 2.0]]]])
    noise = torch.zeros_like(x)
    lpf = torch.ones_like(x)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output
    assert torch.isnan(result).any()

def test_non_float_input():
    # Integer tensors work once explicitly cast to float
    x = torch.ones(1, 1, 4, 4, dtype=torch.int32)
    noise = torch.zeros(1, 1, 4, 4, dtype=torch.int32)
    lpf = torch.ones(1, 1, 4, 4)
    # Should not raise and should return float
    codeflash_output = Dummy()._apply_freq_filter(x.float(), noise.float(), lpf); result = codeflash_output

def test_high_dimensional_input():
    # Test with 5D input
    x = torch.rand(2, 2, 3, 4, 4)
    noise = torch.rand(2, 2, 3, 4, 4)
    lpf = torch.ones(2, 2, 3, 4, 4)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output

# ---------------------- LARGE SCALE TEST CASES ----------------------

def test_large_tensor_performance():
    # Test with a large tensor (but <100MB)
    # 1 * 16 * 128 * 128 * 4 bytes = 1MB
    x = torch.rand(1, 16, 128, 128)
    noise = torch.rand(1, 16, 128, 128)
    lpf = torch.rand(1, 16, 128, 128)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output

def test_large_batch():
    # Test with a large batch dimension
    x = torch.rand(64, 2, 16, 16)
    noise = torch.rand(64, 2, 16, 16)
    lpf = torch.rand(64, 2, 16, 16)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output

def test_large_channel():
    # Test with a large channel dimension
    x = torch.rand(1, 128, 8, 8)
    noise = torch.rand(1, 128, 8, 8)
    lpf = torch.rand(1, 128, 8, 8)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output

def test_large_spatial():
    # Test with large spatial dimensions (but <100MB)
    x = torch.rand(1, 1, 256, 256)
    noise = torch.rand(1, 1, 256, 256)
    lpf = torch.rand(1, 1, 256, 256)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output

def test_gradient_preservation():
    # Check that gradients flow through the operation
    x = torch.ones(1, 1, 8, 8, requires_grad=True)
    noise = torch.zeros(1, 1, 8, 8, requires_grad=True)
    lpf = torch.full((1, 1, 8, 8), 0.5)
    codeflash_output = Dummy()._apply_freq_filter(x, noise, lpf); result = codeflash_output
    s = result.sum()
    s.backward()
    assert x.grad is not None
    assert noise.grad is not None
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-FreeInitMixin._apply_freq_filter-mbdyc1ar` and push.

Codeflash

codeflash-ai bot added the "⚡️ codeflash" label (Optimization PR opened by Codeflash AI) Jun 1, 2025
codeflash-ai bot requested a review from aseembits93 Jun 1, 2025 17:45