Conversation


@codeflash-ai codeflash-ai bot commented Jun 1, 2025

📄 96% (0.96x) speedup for retrieve_timesteps in src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py

⏱️ Runtime : 642 microseconds → 328 microseconds (best of 703 runs)
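(The 96% figure appears to be computed against the optimized runtime: (642 − 328) / 328 ≈ 0.96, i.e. the optimized function runs roughly 1.96× as fast as the original.)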

📝 Explanation and details

Here is an optimized version of your program that reduces runtime by avoiding repeated, heavy calls to `inspect.signature` and `set()` on every invocation.
Instead, it uses per-class caching for the parameter-acceptance checks. The function logic, signature, comments, and exception messages are unchanged, and the optimization is fully compatible with all `scheduler` types. A minimal sketch of the caching pattern follows the summary below.

**Optimization summary:**

- The slowest lines were the `"timesteps" in set(inspect.signature(...).parameters.keys())` and `"sigmas" in set(...)` checks, as shown in your line profile.
- These are replaced with a fast cache lookup and a direct dict membership check, so the cost of inspection is paid only once per scheduler class per parameter.
- All function signatures, comments, and exception handling are preserved. The output remains 100% equivalent.
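
For illustration, here is a minimal sketch of the per-class caching pattern described above. The cache variable `_SET_TIMESTEPS_PARAMS` and the helper `_accepts_param` are hypothetical names used only for this sketch; the actual code in this PR may be structured differently.

import inspect

# Hypothetical module-level cache: maps a scheduler class to the set of
# parameter names accepted by its set_timesteps method.
_SET_TIMESTEPS_PARAMS = {}

def _accepts_param(scheduler, name):
    """Return True if scheduler.set_timesteps accepts a parameter called `name`.

    inspect.signature runs only the first time a given scheduler class is
    seen; every later call is a dict lookup plus a set membership test.
    """
    cls = scheduler.__class__
    params = _SET_TIMESTEPS_PARAMS.get(cls)
    if params is None:
        params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        _SET_TIMESTEPS_PARAMS[cls] = params
    return name in params

# Inside retrieve_timesteps, the original per-call check
#     accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
# then becomes
#     accepts_timesteps = _accepts_param(scheduler, "timesteps")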

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 20 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests Details
import inspect
from typing import List, Optional, Union

# imports
import pytest  # used for our unit tests
import torch
from src.diffusers.pipelines.deprecated.alt_diffusion.pipeline_alt_diffusion import \
    retrieve_timesteps

# ------------------------
# Mock Schedulers for Testing
# ------------------------

class SchedulerBasic:
    """A basic scheduler that supports only num_inference_steps."""
    def __init__(self):
        self.timesteps = None
        self.called_with = {}

    def set_timesteps(self, num_inference_steps, device=None, **kwargs):
        self.called_with = {
            "num_inference_steps": num_inference_steps,
            "device": device,
            **kwargs
        }
        # Generate timesteps as a tensor from num_inference_steps-1 to 0
        self.timesteps = torch.arange(num_inference_steps-1, -1, -1, dtype=torch.long)
        if device is not None:
            self.timesteps = self.timesteps.to(device)

class SchedulerTimesteps:
    """A scheduler that supports custom timesteps."""
    def __init__(self):
        self.timesteps = None
        self.called_with = {}

    def set_timesteps(self, timesteps, device=None, **kwargs):
        self.called_with = {
            "timesteps": timesteps,
            "device": device,
            **kwargs
        }
        self.timesteps = torch.tensor(timesteps, dtype=torch.long)
        if device is not None:
            self.timesteps = self.timesteps.to(device)

class SchedulerSigmas:
    """A scheduler that supports custom sigmas."""
    def __init__(self):
        self.timesteps = None
        self.called_with = {}

    def set_timesteps(self, sigmas, device=None, **kwargs):
        self.called_with = {
            "sigmas": sigmas,
            "device": device,
            **kwargs
        }
        # For testing, just set timesteps as the indices of sigmas
        self.timesteps = torch.arange(len(sigmas), dtype=torch.long)
        if device is not None:
            self.timesteps = self.timesteps.to(device)

class SchedulerTimestepsAndSigmas:
    """A scheduler that supports both custom timesteps and sigmas."""
    def __init__(self):
        self.timesteps = None
        self.called_with = {}

    def set_timesteps(self, timesteps=None, sigmas=None, device=None, **kwargs):
        self.called_with = {
            "timesteps": timesteps,
            "sigmas": sigmas,
            "device": device,
            **kwargs
        }
        if timesteps is not None:
            self.timesteps = torch.tensor(timesteps, dtype=torch.long)
        elif sigmas is not None:
            self.timesteps = torch.arange(len(sigmas), dtype=torch.long)
        else:
            self.timesteps = torch.arange(10, dtype=torch.long)
        if device is not None:
            self.timesteps = self.timesteps.to(device)

class SchedulerNoTimestepsOrSigmas:
    """A scheduler that does not support timesteps or sigmas as arguments."""
    def __init__(self):
        self.timesteps = None
        self.called_with = {}

    def set_timesteps(self, num_inference_steps, device=None, **kwargs):
        self.called_with = {
            "num_inference_steps": num_inference_steps,
            "device": device,
            **kwargs
        }
        self.timesteps = torch.arange(num_inference_steps-1, -1, -1, dtype=torch.long)
        if device is not None:
            self.timesteps = self.timesteps.to(device)

# ------------------------
# Basic Test Cases
# ------------------------

def test_basic_num_inference_steps():
    """Test basic usage with num_inference_steps only."""
    scheduler = SchedulerBasic()
    timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=5)

def test_basic_custom_timesteps():
    """Test providing custom timesteps."""
    scheduler = SchedulerTimesteps()
    custom_ts = [10, 5, 2]
    timesteps, n_steps = retrieve_timesteps(scheduler, timesteps=custom_ts)

def test_basic_custom_sigmas():
    """Test providing custom sigmas."""
    scheduler = SchedulerSigmas()
    custom_sigmas = [0.1, 0.2, 0.3, 0.4]
    timesteps, n_steps = retrieve_timesteps(scheduler, sigmas=custom_sigmas)

def test_basic_device_cpu():
    """Test that device argument moves tensor to cpu."""
    scheduler = SchedulerBasic()
    timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=4, device="cpu")

def test_basic_device_cuda(monkeypatch):
    """Test that device argument moves tensor to cuda (if available)."""
    # Patch torch.Tensor.to so we don't need a real GPU
    called = {}
    orig_to = torch.Tensor.to
    def fake_to(self, device):
        called['device'] = device
        return orig_to(self, "cpu")  # always return cpu tensor for test
    monkeypatch.setattr(torch.Tensor, "to", fake_to)
    scheduler = SchedulerBasic()
    timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=3, device="cuda:0")

def test_basic_kwargs_passed():
    """Test that extra kwargs are passed to scheduler.set_timesteps."""
    scheduler = SchedulerBasic()
    timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=2, foo="bar", baz=42)

# ------------------------
# Edge Test Cases
# ------------------------

def test_timesteps_and_sigmas_raises():
    """Test that passing both timesteps and sigmas raises ValueError."""
    scheduler = SchedulerTimestepsAndSigmas()
    with pytest.raises(ValueError) as e:
        retrieve_timesteps(scheduler, timesteps=[1,2], sigmas=[0.1,0.2])

def test_timesteps_not_supported_raises():
    """Test that passing timesteps to a scheduler that doesn't support them raises ValueError."""
    scheduler = SchedulerBasic()
    with pytest.raises(ValueError) as e:
        retrieve_timesteps(scheduler, timesteps=[1,2,3])

def test_sigmas_not_supported_raises():
    """Test that passing sigmas to a scheduler that doesn't support them raises ValueError."""
    scheduler = SchedulerBasic()
    with pytest.raises(ValueError) as e:
        retrieve_timesteps(scheduler, sigmas=[0.1,0.2])

def test_timesteps_none_and_sigmas_none_and_num_inference_steps_none():
    """Test that if all three are None, set_timesteps is called with None."""
    scheduler = SchedulerNoTimestepsOrSigmas()
    timesteps, n_steps = retrieve_timesteps(scheduler)

def test_empty_timesteps_list():
    """Test that an empty timesteps list works and returns empty tensor."""
    scheduler = SchedulerTimesteps()
    timesteps, n_steps = retrieve_timesteps(scheduler, timesteps=[])

def test_empty_sigmas_list():
    """Test that an empty sigmas list works and returns empty tensor."""
    scheduler = SchedulerSigmas()
    timesteps, n_steps = retrieve_timesteps(scheduler, sigmas=[])

def test_negative_timesteps():
    """Test that negative timesteps are handled correctly."""
    scheduler = SchedulerTimesteps()
    custom_ts = [-5, -1, 0, 1]
    timesteps, n_steps = retrieve_timesteps(scheduler, timesteps=custom_ts)

def test_noninteger_timesteps():
    """Test that non-integer timesteps raise an error in the scheduler."""
    class BadScheduler:
        def set_timesteps(self, timesteps, device=None, **kwargs):
            # Simulate error if non-integer in timesteps
            if any(not isinstance(x, int) for x in timesteps):
                raise TypeError("All timesteps must be integers")
            self.timesteps = torch.tensor(timesteps, dtype=torch.long)
    scheduler = BadScheduler()
    with pytest.raises(TypeError):
        retrieve_timesteps(scheduler, timesteps=[1, 2.5, 3])

def test_nonfloat_sigmas():
    """Test that non-float sigmas raise an error in the scheduler."""
    class BadScheduler:
        def set_timesteps(self, sigmas, device=None, **kwargs):
            if any(not isinstance(x, float) for x in sigmas):
                raise TypeError("All sigmas must be floats")
            self.timesteps = torch.arange(len(sigmas), dtype=torch.long)
    scheduler = BadScheduler()
    with pytest.raises(TypeError):
        retrieve_timesteps(scheduler, sigmas=[0.1, 2, 0.3])

def test_kwargs_override():
    """Test that kwargs can override device argument."""
    class CustomScheduler:
        def __init__(self):
            self.timesteps = None
            self.called_with = {}
        def set_timesteps(self, num_inference_steps, device=None, **kwargs):
            self.called_with = {"num_inference_steps": num_inference_steps, "device": device, **kwargs}
            self.timesteps = torch.arange(num_inference_steps, dtype=torch.long)
            if device is not None:
                self.timesteps = self.timesteps.to(device)
    scheduler = CustomScheduler()
    timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=5, device="cpu", device_override="cuda:0")

# ------------------------
# Large Scale Test Cases
# ------------------------

def test_large_num_inference_steps():
    """Test with a large number of inference steps (up to 1000)."""
    scheduler = SchedulerBasic()
    n = 1000
    timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=n)

def test_large_custom_timesteps():
    """Test with a large custom timesteps list (up to 1000 elements)."""
    scheduler = SchedulerTimesteps()
    custom_ts = list(range(1000))
    timesteps, n_steps = retrieve_timesteps(scheduler, timesteps=custom_ts)

def test_large_custom_sigmas():
    """Test with a large custom sigmas list (up to 1000 elements)."""
    scheduler = SchedulerSigmas()
    custom_sigmas = [float(i)/1000 for i in range(1000)]
    timesteps, n_steps = retrieve_timesteps(scheduler, sigmas=custom_sigmas)

def test_large_device_cpu():
    """Test with large tensor and device='cpu'."""
    scheduler = SchedulerBasic()
    n = 1000
    timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=n, device="cpu")

def test_large_kwargs():
    """Test that large kwargs dict is handled."""
    scheduler = SchedulerBasic()
    large_kwargs = {f"key{i}": i for i in range(50)}
    timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=10, **large_kwargs)
    # verify the extra kwargs were forwarded to scheduler.set_timesteps
    for i in range(50):
        assert scheduler.called_with[f"key{i}"] == i

def test_large_multiple_calls_consistency():
    """Test that multiple large calls are consistent and do not leak state."""
    scheduler = SchedulerBasic()
    for n in [100, 200, 300]:
        timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=n)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

import inspect
from typing import List, Optional, Union

# imports
import pytest  # used for our unit tests
import torch
from src.diffusers.pipelines.deprecated.alt_diffusion.pipeline_alt_diffusion import \
    retrieve_timesteps

# ---------------------------
# Mock scheduler definitions
# ---------------------------

class SchedulerBasic:
    """A scheduler supporting only num_inference_steps."""
    def __init__(self):
        self.timesteps = None
        self.last_args = None
    def set_timesteps(self, num_inference_steps, device=None, **kwargs):
        # Save arguments for inspection
        self.last_args = (num_inference_steps, device, kwargs)
        # Timesteps are descending ints from num_inference_steps-1 to 0
        self.timesteps = torch.arange(num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)

class SchedulerWithTimesteps:
    """A scheduler supporting custom timesteps."""
    def __init__(self):
        self.timesteps = None
        self.last_args = None
    def set_timesteps(self, timesteps, device=None, **kwargs):
        # Save arguments for inspection
        self.last_args = (timesteps, device, kwargs)
        # Convert list to tensor
        self.timesteps = torch.tensor(timesteps, device=device, dtype=torch.long)

class SchedulerWithSigmas:
    """A scheduler supporting custom sigmas."""
    def __init__(self):
        self.timesteps = None
        self.last_args = None
    def set_timesteps(self, sigmas, device=None, **kwargs):
        # Save arguments for inspection
        self.last_args = (sigmas, device, kwargs)
        # Use sigmas as float tensor
        self.timesteps = torch.tensor(sigmas, device=device, dtype=torch.float)

class SchedulerWithTimestepsAndSigmas:
    """A scheduler supporting both custom timesteps and sigmas."""
    def __init__(self):
        self.timesteps = None
        self.last_args = None
    def set_timesteps(self, num_inference_steps=None, timesteps=None, sigmas=None, device=None, **kwargs):
        self.last_args = (num_inference_steps, timesteps, sigmas, device, kwargs)
        if timesteps is not None:
            self.timesteps = torch.tensor(timesteps, device=device, dtype=torch.long)
        elif sigmas is not None:
            self.timesteps = torch.tensor(sigmas, device=device, dtype=torch.float)
        elif num_inference_steps is not None:
            self.timesteps = torch.arange(num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
        else:
            raise ValueError("No valid input for set_timesteps")

# ---------------------------
# Unit tests
# ---------------------------

# ---------------------------
# Basic Test Cases
# ---------------------------

def test_basic_num_inference_steps():
    # Test standard usage with num_inference_steps on a scheduler that supports only num_inference_steps
    scheduler = SchedulerBasic()
    steps = 10
    timesteps, nsteps = retrieve_timesteps(scheduler, num_inference_steps=steps)

def test_basic_custom_timesteps():
    # Test with custom timesteps on a scheduler that supports timesteps
    scheduler = SchedulerWithTimesteps()
    custom_steps = [0, 2, 4, 6]
    timesteps, nsteps = retrieve_timesteps(scheduler, timesteps=custom_steps)

def test_basic_custom_sigmas():
    # Test with custom sigmas on a scheduler that supports sigmas
    scheduler = SchedulerWithSigmas()
    custom_sigmas = [0.1, 0.5, 0.9]
    timesteps, nsteps = retrieve_timesteps(scheduler, sigmas=custom_sigmas)

def test_basic_both_timesteps_and_sigmas_supported():
    # Test with a scheduler supporting both timesteps and sigmas
    scheduler = SchedulerWithTimestepsAndSigmas()
    # Use timesteps
    custom_steps = [1, 3, 5]
    timesteps, nsteps = retrieve_timesteps(scheduler, timesteps=custom_steps)
    # Use sigmas
    custom_sigmas = [0.2, 0.4]
    timesteps2, nsteps2 = retrieve_timesteps(scheduler, sigmas=custom_sigmas)
    # Use num_inference_steps
    timesteps3, nsteps3 = retrieve_timesteps(scheduler, num_inference_steps=4)

def test_basic_device_argument_cpu():
    # Test that device is passed through and works (CPU)
    scheduler = SchedulerWithTimesteps()
    custom_steps = [7, 8, 9]
    timesteps, nsteps = retrieve_timesteps(scheduler, timesteps=custom_steps, device="cpu")

# ---------------------------
# Edge Test Cases
# ---------------------------

def test_edge_timesteps_and_sigmas_both_given():
    # Should raise ValueError if both timesteps and sigmas are given
    scheduler = SchedulerWithTimestepsAndSigmas()
    with pytest.raises(ValueError) as excinfo:
        retrieve_timesteps(scheduler, timesteps=[1,2], sigmas=[0.1, 0.2])

def test_edge_timesteps_not_supported():
    # Should raise ValueError if timesteps are given but scheduler does not support them
    scheduler = SchedulerBasic()
    with pytest.raises(ValueError) as excinfo:
        retrieve_timesteps(scheduler, timesteps=[1,2,3])

def test_edge_sigmas_not_supported():
    # Should raise ValueError if sigmas are given but scheduler does not support them
    scheduler = SchedulerWithTimesteps()
    with pytest.raises(ValueError) as excinfo:
        retrieve_timesteps(scheduler, sigmas=[0.1, 0.2])

def test_edge_num_inference_steps_zero():
    # Should handle num_inference_steps=0 (returns empty tensor)
    scheduler = SchedulerBasic()
    timesteps, nsteps = retrieve_timesteps(scheduler, num_inference_steps=0)

def test_edge_empty_timesteps():
    # Should handle empty timesteps list
    scheduler = SchedulerWithTimesteps()
    timesteps, nsteps = retrieve_timesteps(scheduler, timesteps=[])

def test_edge_empty_sigmas():
    # Should handle empty sigmas list
    scheduler = SchedulerWithSigmas()
    timesteps, nsteps = retrieve_timesteps(scheduler, sigmas=[])

def test_edge_device_argument_cuda():
    # Test with device="cuda" if available; otherwise skip
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    scheduler = SchedulerWithTimesteps()
    custom_steps = [1, 2, 3]
    timesteps, nsteps = retrieve_timesteps(scheduler, timesteps=custom_steps, device="cuda")


def test_edge_scheduler_set_timesteps_missing():
    # Should raise AttributeError if scheduler does not have set_timesteps at all
    class SchedulerNoSetTimesteps:
        pass
    scheduler = SchedulerNoSetTimesteps()
    with pytest.raises(AttributeError):
        retrieve_timesteps(scheduler, num_inference_steps=5)


def test_edge_kwargs_conflict():
    # Should not raise if extra kwargs are ignored by set_timesteps
    class SchedulerIgnoreKwargs:
        def __init__(self):
            self.timesteps = None
        def set_timesteps(self, num_inference_steps, device=None, **kwargs):
            self.timesteps = torch.arange(num_inference_steps - 1, -1, -1, device=device)
    scheduler = SchedulerIgnoreKwargs()
    timesteps, nsteps = retrieve_timesteps(scheduler, num_inference_steps=3, foo="bar")

# ---------------------------
# Large Scale Test Cases
# ---------------------------

def test_large_num_inference_steps():
    # Test with a large number of inference steps (e.g., 1000)
    scheduler = SchedulerBasic()
    steps = 1000
    timesteps, nsteps = retrieve_timesteps(scheduler, num_inference_steps=steps)

def test_large_custom_timesteps():
    # Test with a large custom timesteps list (e.g., 1000 elements)
    scheduler = SchedulerWithTimesteps()
    custom_steps = list(range(1000))
    timesteps, nsteps = retrieve_timesteps(scheduler, timesteps=custom_steps)

def test_large_custom_sigmas():
    # Test with a large custom sigmas list (e.g., 1000 elements)
    scheduler = SchedulerWithSigmas()
    custom_sigmas = [float(i) / 1000 for i in range(1000)]
    timesteps, nsteps = retrieve_timesteps(scheduler, sigmas=custom_sigmas)

def test_large_device_argument_cpu():
    # Test large input with device="cpu"
    scheduler = SchedulerWithTimesteps()
    custom_steps = list(range(1000))
    timesteps, nsteps = retrieve_timesteps(scheduler, timesteps=custom_steps, device="cpu")
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-retrieve_timesteps-mbdmw0d7`, make your edits, and push.

Codeflash

@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Jun 1, 2025
@codeflash-ai codeflash-ai bot requested a review from aseembits93 June 1, 2025 12:25