Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Jun 1, 2025

📄 88% (0.88x) speedup for retrieve_timesteps in src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py

⏱️ Runtime : 817 microseconds 434 microseconds (best of 629 runs)

📝 Explanation and details

Optimization summary.

  • Profile showed inspect.signature to be slow and called every function call. Replaced with persistent class-level caches for timesteps and sigmas key presence per scheduler class.
  • No change to API, output, or signatures. Output and error behavior is preserved. All required comments retained.
  • CPU- and I/O-intensive lines now run only once per scheduler class, making repeated calls much faster.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 23 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests Details
import inspect
from typing import List, Optional, Union

# imports
import pytest  # used for our unit tests
import torch
from src.diffusers.pipelines.cogvideo.pipeline_cogvideox_video2video import \
    retrieve_timesteps

# --- Mock Schedulers for Testing ---

class BasicScheduler:
    """
    Scheduler supporting only num_inference_steps.
    """
    def __init__(self):
        self.timesteps = None
        self.called_args = None

    def set_timesteps(self, num_inference_steps, device=None, **kwargs):
        # Save args for inspection
        self.called_args = (num_inference_steps, device)
        # Create a simple decreasing sequence
        self.timesteps = torch.arange(num_inference_steps-1, -1, -1, dtype=torch.long)
        if device is not None:
            self.timesteps = self.timesteps.to(device)

class TimestepsScheduler:
    """
    Scheduler supporting custom timesteps.
    """
    def __init__(self):
        self.timesteps = None
        self.called_args = None

    def set_timesteps(self, timesteps, device=None, **kwargs):
        self.called_args = (timesteps, device)
        self.timesteps = torch.tensor(timesteps, dtype=torch.long)
        if device is not None:
            self.timesteps = self.timesteps.to(device)

class SigmasScheduler:
    """
    Scheduler supporting custom sigmas.
    """
    def __init__(self):
        self.timesteps = None
        self.called_args = None

    def set_timesteps(self, sigmas, device=None, **kwargs):
        self.called_args = (sigmas, device)
        # For testing, just store the sigmas as a tensor
        self.timesteps = torch.tensor(sigmas, dtype=torch.float)
        if device is not None:
            self.timesteps = self.timesteps.to(device)

class BothScheduler:
    """
    Scheduler supporting both custom timesteps and sigmas.
    """
    def __init__(self):
        self.timesteps = None
        self.called_args = None

    def set_timesteps(self, num_inference_steps=None, timesteps=None, sigmas=None, device=None, **kwargs):
        self.called_args = (num_inference_steps, timesteps, sigmas, device)
        if timesteps is not None:
            self.timesteps = torch.tensor(timesteps, dtype=torch.long)
        elif sigmas is not None:
            self.timesteps = torch.tensor(sigmas, dtype=torch.float)
        elif num_inference_steps is not None:
            self.timesteps = torch.arange(num_inference_steps-1, -1, -1, dtype=torch.long)
        else:
            raise ValueError("No valid argument provided to set_timesteps.")
        if device is not None:
            self.timesteps = self.timesteps.to(device)

# --------- UNIT TESTS ---------

# --------- BASIC TEST CASES ---------

def test_basic_num_inference_steps():
    # Test with num_inference_steps only, using BasicScheduler
    scheduler = BasicScheduler()
    timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=5)

def test_basic_custom_timesteps():
    # Test with custom timesteps, using TimestepsScheduler
    scheduler = TimestepsScheduler()
    custom_ts = [10, 5, 1]
    timesteps, n = retrieve_timesteps(scheduler, timesteps=custom_ts)

def test_basic_custom_sigmas():
    # Test with custom sigmas, using SigmasScheduler
    scheduler = SigmasScheduler()
    custom_sigmas = [0.5, 0.2, 0.1]
    timesteps, n = retrieve_timesteps(scheduler, sigmas=custom_sigmas)

def test_basic_bothscheduler_num_inference_steps():
    # BothScheduler supports all modes
    scheduler = BothScheduler()
    timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=4)

def test_basic_bothscheduler_timesteps():
    scheduler = BothScheduler()
    custom_ts = [8,4,2]
    timesteps, n = retrieve_timesteps(scheduler, timesteps=custom_ts)

def test_basic_bothscheduler_sigmas():
    scheduler = BothScheduler()
    custom_sigmas = [0.9, 0.5]
    timesteps, n = retrieve_timesteps(scheduler, sigmas=custom_sigmas)

# --------- EDGE TEST CASES ---------

def test_timesteps_and_sigmas_raises():
    # Should raise if both timesteps and sigmas are provided
    scheduler = BothScheduler()
    with pytest.raises(ValueError) as excinfo:
        retrieve_timesteps(scheduler, timesteps=[1,2], sigmas=[0.1,0.2])

def test_timesteps_not_supported_raises():
    # Scheduler does not support custom timesteps
    scheduler = BasicScheduler()
    with pytest.raises(ValueError) as excinfo:
        retrieve_timesteps(scheduler, timesteps=[1,2])

def test_sigmas_not_supported_raises():
    # Scheduler does not support custom sigmas
    scheduler = BasicScheduler()
    with pytest.raises(ValueError) as excinfo:
        retrieve_timesteps(scheduler, sigmas=[0.1,0.2])

def test_none_num_inference_steps():
    # Should work if num_inference_steps is None but scheduler doesn't require it
    class DummyScheduler:
        def __init__(self):
            self.timesteps = torch.tensor([0])
        def set_timesteps(self, num_inference_steps, device=None, **kwargs):
            # Accepts None, sets timesteps to [0]
            self.timesteps = torch.tensor([0])
    scheduler = DummyScheduler()
    timesteps, n = retrieve_timesteps(scheduler)

def test_empty_timesteps_list():
    # Should work and return empty tensor if timesteps is []
    scheduler = TimestepsScheduler()
    timesteps, n = retrieve_timesteps(scheduler, timesteps=[])

def test_empty_sigmas_list():
    # Should work and return empty tensor if sigmas is []
    scheduler = SigmasScheduler()
    timesteps, n = retrieve_timesteps(scheduler, sigmas=[])

def test_device_argument_cpu():
    # Should move timesteps to cpu if device='cpu'
    scheduler = BothScheduler()
    timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=3, device='cpu')

def test_device_argument_cuda(monkeypatch):
    # Should move timesteps to cuda if device='cuda', if cuda is available
    scheduler = BothScheduler()
    if torch.cuda.is_available():
        timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=2, device='cuda')
    else:
        # If cuda is not available, skip
        pytest.skip("CUDA not available on this machine.")

def test_kwargs_passed_to_scheduler():
    # Should pass through kwargs to scheduler.set_timesteps
    class KwargScheduler:
        def __init__(self):
            self.timesteps = None
            self.kwargs = None
        def set_timesteps(self, num_inference_steps, device=None, **kwargs):
            self.timesteps = torch.arange(num_inference_steps-1, -1, -1)
            self.kwargs = kwargs
    scheduler = KwargScheduler()
    timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=2, foo='bar', baz=42)

def test_scheduler_returns_non_tensor_timesteps():
    # Should handle if scheduler.timesteps is not a tensor (should not happen in real use)
    class WeirdScheduler:
        def __init__(self):
            self.timesteps = None
        def set_timesteps(self, num_inference_steps, device=None, **kwargs):
            self.timesteps = [3,2,1]
    scheduler = WeirdScheduler()
    timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=3)

def test_scheduler_set_timesteps_missing_raises():
    # Should raise if scheduler has no set_timesteps method
    class NoSetTimesteps:
        pass
    scheduler = NoSetTimesteps()
    with pytest.raises(AttributeError):
        retrieve_timesteps(scheduler, num_inference_steps=3)

# --------- LARGE SCALE TEST CASES ---------

def test_large_num_inference_steps():
    # Test with a large num_inference_steps (1000)
    scheduler = BasicScheduler()
    timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=1000)

def test_large_custom_timesteps():
    # Test with a large custom timesteps list (1000 elements)
    scheduler = TimestepsScheduler()
    custom_ts = list(range(1000, 0, -1))
    timesteps, n = retrieve_timesteps(scheduler, timesteps=custom_ts)

def test_large_custom_sigmas():
    # Test with a large custom sigmas list (1000 elements)
    scheduler = SigmasScheduler()
    custom_sigmas = [float(i)/1000 for i in range(1000)]
    timesteps, n = retrieve_timesteps(scheduler, sigmas=custom_sigmas)

def test_large_device_argument_cpu():
    # Test large scale with device argument
    scheduler = BothScheduler()
    timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=1000, device='cpu')

def test_large_kwargs():
    # Test large scale with extra kwargs
    class LargeKwargScheduler:
        def __init__(self):
            self.timesteps = None
            self.kwargs = None
        def set_timesteps(self, num_inference_steps, device=None, **kwargs):
            self.timesteps = torch.arange(num_inference_steps-1, -1, -1)
            self.kwargs = kwargs
    scheduler = LargeKwargScheduler()
    # Pass a large kwarg dict
    extra_kwargs = {f'key{i}': i for i in range(20)}
    timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=1000, **extra_kwargs)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

import inspect
from typing import List, Optional, Union

# imports
import pytest  # used for our unit tests
import torch
from src.diffusers.pipelines.cogvideo.pipeline_cogvideox_video2video import \
    retrieve_timesteps

# --- Dummy Schedulers for testing ---

class BasicScheduler:
    """
    Scheduler with set_timesteps(num_inference_steps, device=None, **kwargs)
    """
    def __init__(self):
        self.timesteps = None
        self.device = None
        self.called_with = {}

    def set_timesteps(self, num_inference_steps, device=None, **kwargs):
        self.device = device
        self.called_with = dict(num_inference_steps=num_inference_steps, device=device, **kwargs)
        # Timesteps is a tensor of [0, ..., num_inference_steps-1]
        self.timesteps = torch.arange(num_inference_steps)

class TimestepsScheduler:
    """
    Scheduler with set_timesteps(timesteps, device=None, **kwargs)
    """
    def __init__(self):
        self.timesteps = None
        self.device = None
        self.called_with = {}

    def set_timesteps(self, timesteps, device=None, **kwargs):
        self.device = device
        self.called_with = dict(timesteps=timesteps, device=device, **kwargs)
        # Timesteps is a tensor of the supplied timesteps
        self.timesteps = torch.tensor(timesteps)

class SigmasScheduler:
    """
    Scheduler with set_timesteps(sigmas, device=None, **kwargs)
    """
    def __init__(self):
        self.timesteps = None
        self.device = None
        self.called_with = {}

    def set_timesteps(self, sigmas, device=None, **kwargs):
        self.device = device
        self.called_with = dict(sigmas=sigmas, device=device, **kwargs)
        # Timesteps is a tensor of the same length as sigmas, filled with -1 for test
        self.timesteps = torch.full((len(sigmas),), -1.0)

class MultiScheduler:
    """
    Scheduler with set_timesteps(num_inference_steps=None, timesteps=None, sigmas=None, device=None, **kwargs)
    """
    def __init__(self):
        self.timesteps = None
        self.device = None
        self.called_with = {}

    def set_timesteps(self, num_inference_steps=None, timesteps=None, sigmas=None, device=None, **kwargs):
        self.device = device
        self.called_with = dict(num_inference_steps=num_inference_steps, timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
        if timesteps is not None:
            self.timesteps = torch.tensor(timesteps)
        elif sigmas is not None:
            self.timesteps = torch.full((len(sigmas),), -1.0)
        elif num_inference_steps is not None:
            self.timesteps = torch.arange(num_inference_steps)
        else:
            self.timesteps = torch.tensor([])

class NoTimestepsScheduler:
    """
    Scheduler with set_timesteps(num_inference_steps, device=None, **kwargs) but no timesteps attribute
    """
    def set_timesteps(self, num_inference_steps, device=None, **kwargs):
        pass

# --- Unit tests ---

# -------------------------
# 1. Basic Test Cases
# -------------------------

def test_basic_num_inference_steps():
    # Test with BasicScheduler and num_inference_steps
    scheduler = BasicScheduler()
    timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=5)

def test_basic_timesteps():
    # Test with TimestepsScheduler and custom timesteps
    scheduler = TimestepsScheduler()
    custom = [10, 20, 30]
    timesteps, n = retrieve_timesteps(scheduler, timesteps=custom)

def test_basic_sigmas():
    # Test with SigmasScheduler and custom sigmas
    scheduler = SigmasScheduler()
    custom_sigmas = [0.1, 0.2, 0.3, 0.4]
    timesteps, n = retrieve_timesteps(scheduler, sigmas=custom_sigmas)

def test_basic_multi_scheduler_timesteps():
    # Test with MultiScheduler and custom timesteps
    scheduler = MultiScheduler()
    custom = [1, 2, 3, 4]
    timesteps, n = retrieve_timesteps(scheduler, timesteps=custom)

def test_basic_multi_scheduler_sigmas():
    # Test with MultiScheduler and custom sigmas
    scheduler = MultiScheduler()
    sigmas = [0.5, 0.6]
    timesteps, n = retrieve_timesteps(scheduler, sigmas=sigmas)

def test_basic_multi_scheduler_num_inference_steps():
    # Test with MultiScheduler and num_inference_steps
    scheduler = MultiScheduler()
    timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=3)

def test_device_argument():
    # Test that device argument is passed through
    scheduler = BasicScheduler()
    device = torch.device("cpu")
    retrieve_timesteps(scheduler, num_inference_steps=2, device=device)

# -------------------------
# 2. Edge Test Cases
# -------------------------

def test_both_timesteps_and_sigmas_raises():
    # Should raise ValueError if both timesteps and sigmas are provided
    scheduler = MultiScheduler()
    with pytest.raises(ValueError):
        retrieve_timesteps(scheduler, timesteps=[1], sigmas=[0.1])

def test_timesteps_not_supported_raises():
    # Should raise ValueError if timesteps is passed but scheduler does not accept it
    scheduler = BasicScheduler()
    with pytest.raises(ValueError):
        retrieve_timesteps(scheduler, timesteps=[1,2,3])

def test_sigmas_not_supported_raises():
    # Should raise ValueError if sigmas is passed but scheduler does not accept it
    scheduler = TimestepsScheduler()
    with pytest.raises(ValueError):
        retrieve_timesteps(scheduler, sigmas=[0.1,0.2])

def test_no_timesteps_attribute():
    # If scheduler does not set timesteps attribute, should raise AttributeError
    scheduler = NoTimestepsScheduler()
    with pytest.raises(AttributeError):
        retrieve_timesteps(scheduler, num_inference_steps=1)

def test_empty_timesteps():
    # Should work with empty timesteps
    scheduler = TimestepsScheduler()
    timesteps, n = retrieve_timesteps(scheduler, timesteps=[])

def test_empty_sigmas():
    # Should work with empty sigmas
    scheduler = SigmasScheduler()
    timesteps, n = retrieve_timesteps(scheduler, sigmas=[])

def test_zero_num_inference_steps():
    # Should work with num_inference_steps=0
    scheduler = BasicScheduler()
    timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=0)

def test_kwargs_passed_through():
    # Check that kwargs are passed to scheduler.set_timesteps
    scheduler = BasicScheduler()
    retrieve_timesteps(scheduler, num_inference_steps=2, foo="bar")

def test_device_as_string():
    # Test device as string
    scheduler = BasicScheduler()
    retrieve_timesteps(scheduler, num_inference_steps=1, device="cpu")

def test_only_one_of_timesteps_or_sigmas():
    # Should succeed if only one of timesteps or sigmas is given
    scheduler = MultiScheduler()
    timesteps, n = retrieve_timesteps(scheduler, timesteps=[1,2,3])
    timesteps, n = retrieve_timesteps(scheduler, sigmas=[0.1,0.2])

# -------------------------
# 3. Large Scale Test Cases
# -------------------------

def test_large_num_inference_steps():
    # Test with large num_inference_steps (e.g., 1000)
    scheduler = BasicScheduler()
    N = 1000
    timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=N)

def test_large_custom_timesteps():
    # Test with large custom timesteps (e.g., 1000 elements)
    scheduler = TimestepsScheduler()
    custom = list(range(1000, 2000))
    timesteps, n = retrieve_timesteps(scheduler, timesteps=custom)

def test_large_custom_sigmas():
    # Test with large custom sigmas (e.g., 1000 elements)
    scheduler = SigmasScheduler()
    sigmas = [float(i)/1000 for i in range(1000)]
    timesteps, n = retrieve_timesteps(scheduler, sigmas=sigmas)

def test_large_multi_scheduler_timesteps():
    # Test MultiScheduler with large timesteps
    scheduler = MultiScheduler()
    custom = list(range(500, 1500))
    timesteps, n = retrieve_timesteps(scheduler, timesteps=custom)

def test_large_multi_scheduler_sigmas():
    # Test MultiScheduler with large sigmas
    scheduler = MultiScheduler()
    sigmas = [float(i)/1000 for i in range(1000)]
    timesteps, n = retrieve_timesteps(scheduler, sigmas=sigmas)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-retrieve_timesteps-mbdom5ys and push.

Codeflash

**Optimization summary**.
- Profile showed `inspect.signature` to be slow and called every function call. Replaced with persistent class-level caches for `timesteps` and `sigmas` key presence per scheduler class.
- No change to API, output, or signatures. Output and error behavior is preserved. All required comments retained.  
- CPU- and I/O-intensive lines now run only once per scheduler class, making repeated calls much faster.
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Jun 1, 2025
@codeflash-ai codeflash-ai bot requested a review from aseembits93 June 1, 2025 13:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants