Conversation

codeflash-ai[bot] commented on Oct 23, 2025

📄 **93% (0.93x) speedup** for `get_cross_encoder_activation_function` in `python/sglang/srt/layers/activation.py`

⏱️ Runtime: 26.0 milliseconds → 13.5 milliseconds (best of 79 runs)

📝 Explanation and details

The optimization adds an `@lru_cache(maxsize=128)` decorator to the `resolve_obj_by_qualname` function in `sglang/utils.py`. This caching mechanism provides a **92% speedup** by eliminating redundant module imports and attribute lookups.
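
As a rough sketch (not the exact `sglang/utils.py` source, just the shape of the change), the cached resolver looks roughly like this:

```python
# Minimal sketch of the cached resolver; the real resolve_obj_by_qualname in
# sglang/utils.py may differ in details, but the @lru_cache decorator is the change.
import importlib
from functools import lru_cache


@lru_cache(maxsize=128)
def resolve_obj_by_qualname(qualname: str):
    """Resolve e.g. 'torch.nn.modules.activation.ReLU' to the ReLU class,
    caching the result so repeated lookups skip the import machinery."""
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)  # expensive only on a cache miss
    return getattr(module, obj_name)
```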

**Key optimization:** The line profiler shows that `importlib.import_module(module_name)` was the primary bottleneck, consuming 96.9% of execution time in the original code. Module imports are expensive operations that involve file system access, parsing, and Python's import machinery. With caching, subsequent calls that resolve the same qualified name (such as `"torch.nn.modules.activation.ReLU"`) bypass the import entirely and return the cached result.

**Performance impact:** The cache reduces the critical line from 94.25 ms to 32.3 ms of total execution time, a 66% reduction in the most expensive operation. This is particularly effective for workloads that repeatedly request the same activation functions, as shown in the test cases where 1000+ configs use identical activation functions.
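
A rough, machine-dependent way to reproduce the effect locally is to time repeated lookups against a stub config, mirroring the workload the report describes (the `SimpleNamespace` config here is a test stand-in, and the numbers will not match the report exactly):

```python
import timeit
from types import SimpleNamespace

from sglang.srt.layers.activation import get_cross_encoder_activation_function

# Stub config carrying the only attribute the function reads.
config = SimpleNamespace(
    sbert_ce_default_activation_function="torch.nn.modules.activation.GELU"
)

# Time 1000 lookups of the same activation function; with the cache, only the
# first call pays for the module import.
elapsed = timeit.timeit(lambda: get_cross_encoder_activation_function(config), number=1000)
print(f"1000 lookups took {elapsed * 1e3:.1f} ms")
```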

**Why it works:** The cache is safe because module imports are idempotent: importing the same module multiple times returns the same object. The `maxsize=128` limit provides enough capacity for the typical variety of activation functions while preventing unbounded memory growth. The optimization is most beneficial for batch processing scenarios and repeated model instantiation with common activation functions.
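
A quick hypothetical check of the caching behavior (assuming `sglang/utils.py` is importable as `sglang.utils`): the second call for the same qualified name is served from the cache and returns the very same class object.

```python
from sglang.utils import resolve_obj_by_qualname

relu_a = resolve_obj_by_qualname("torch.nn.modules.activation.ReLU")
relu_b = resolve_obj_by_qualname("torch.nn.modules.activation.ReLU")

assert relu_a is relu_b  # imports are idempotent, so the cached object is identical
print(resolve_obj_by_qualname.cache_info())
# e.g. CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)
```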

Correctness verification report:

| Test                           | Status        |
|--------------------------------|---------------|
| ⚙️ Existing Unit Tests         | 🔘 None Found |
| 🌀 Generated Regression Tests  | 3533 Passed   |
| ⏪ Replay Tests                | 🔘 None Found |
| 🔎 Concolic Coverage Tests     | 🔘 None Found |
| 📊 Tests Coverage              | 100.0%        |
🌀 Generated Regression Tests and Runtime
```python
import importlib
from typing import Any

# imports
import pytest  # used for our unit tests
import torch.nn as nn
from sglang.srt.layers.activation import get_cross_encoder_activation_function


class PretrainedConfig:
    """
    Minimal PretrainedConfig stub for testing.
    """
    def __init__(self, sbert_ce_default_activation_function=None):
        self.sbert_ce_default_activation_function = sbert_ce_default_activation_function

# unit tests

# -------------------------------
# Basic Test Cases
# -------------------------------

def test_identity_returned_when_no_activation_specified():
    # Test: config with no sbert_ce_default_activation_function returns nn.Identity
    config = PretrainedConfig()
    codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output

def test_identity_returned_when_activation_is_none():
    # Test: config with sbert_ce_default_activation_function=None returns nn.Identity
    config = PretrainedConfig(sbert_ce_default_activation_function=None)
    codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output

def test_valid_activation_function_relu():
    # Test: config with valid torch.nn.modules.ReLU returns nn.ReLU instance
    config = PretrainedConfig(sbert_ce_default_activation_function="torch.nn.modules.activation.ReLU")
    codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output

def test_valid_activation_function_gelu():
    # Test: config with valid torch.nn.modules.GELU returns nn.GELU instance
    config = PretrainedConfig(sbert_ce_default_activation_function="torch.nn.modules.activation.GELU")
    codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output

def test_valid_activation_function_sigmoid():
    # Test: config with valid torch.nn.modules.Sigmoid returns nn.Sigmoid instance
    config = PretrainedConfig(sbert_ce_default_activation_function="torch.nn.modules.activation.Sigmoid")
    codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output

# -------------------------------
# Edge Test Cases
# -------------------------------

def test_activation_function_wrong_prefix_raises():
    # Test: config with activation function not starting with torch.nn.modules raises assertion
    config = PretrainedConfig(sbert_ce_default_activation_function="torch.nn.functional.relu")
    try:
        get_cross_encoder_activation_function(config)
    except AssertionError as e:
        pass
    else:
        raise AssertionError("Should raise AssertionError for wrong prefix")

def test_activation_function_empty_string_raises():
    # Test: config with empty string as activation function raises assertion
    config = PretrainedConfig(sbert_ce_default_activation_function="")
    try:
        get_cross_encoder_activation_function(config)
    except AssertionError as e:
        pass
    else:
        raise AssertionError("Should raise AssertionError for empty string")

def test_activation_function_nonexistent_raises_importerror():
    # Test: config with non-existent activation function raises ImportError or AttributeError
    config = PretrainedConfig(sbert_ce_default_activation_function="torch.nn.modules.activation.FakeActivation")
    try:
        get_cross_encoder_activation_function(config)
    except AttributeError:
        pass  # Expected, as FakeActivation does not exist
    except ImportError:
        pass  # Acceptable, if module import fails
    else:
        raise AssertionError("Should raise AttributeError or ImportError for non-existent activation")

def test_config_missing_attribute_returns_identity():
    # Test: config without sbert_ce_default_activation_function attribute returns nn.Identity
    class DummyConfig:
        pass
    config = DummyConfig()
    codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output

def test_config_attribute_is_not_string_raises():
    # Test: config with sbert_ce_default_activation_function as an integer raises AttributeError
    config = PretrainedConfig(sbert_ce_default_activation_function=123)
    try:
        get_cross_encoder_activation_function(config)
    except AttributeError:
        pass  # resolve_obj_by_qualname will fail
    except Exception as e:
        pass
    else:
        raise AssertionError("Should raise an exception for non-string activation function")

# -------------------------------
# Large Scale Test Cases
# -------------------------------

def test_multiple_activation_functions():
    # Test: running the function for a large number of valid activation functions
    # Only use activations present in torch.nn.modules.activation
    valid_activations = [
        "ReLU", "GELU", "Sigmoid", "Tanh", "Softmax", "Softplus", "Softsign", "ELU", "SELU", "CELU", "LeakyReLU"
    ]
    for act_name in valid_activations:
        qualname = f"torch.nn.modules.activation.{act_name}"
        config = PretrainedConfig(sbert_ce_default_activation_function=qualname)
        codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output

def test_large_batch_of_configs_returns_identity_or_module():
    # Test: running the function for a large batch of configs with mixed activation functions
    configs = []
    for i in range(500):
        if i % 2 == 0:
            configs.append(PretrainedConfig(sbert_ce_default_activation_function=None))
        else:
            configs.append(PretrainedConfig(sbert_ce_default_activation_function="torch.nn.modules.activation.ReLU"))
    for i, config in enumerate(configs):
        codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output
        if i % 2 == 0:
            pass
        else:
            pass

def test_performance_under_large_scale():
    # Test: function performance with a batch of 1000 configs (all valid)
    configs = [PretrainedConfig(sbert_ce_default_activation_function="torch.nn.modules.activation.GELU") for _ in range(1000)]
    for config in configs:
        codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output

# -------------------------------
# Additional Edge Cases
# -------------------------------

def test_activation_function_is_whitespace_string_raises():
    # Test: config with whitespace string as activation function raises assertion
    config = PretrainedConfig(sbert_ce_default_activation_function="   ")
    try:
        get_cross_encoder_activation_function(config)
    except AssertionError as e:
        pass
    else:
        raise AssertionError("Should raise AssertionError for whitespace string")

def test_activation_function_is_non_string_type_raises():
    # Test: config with activation function as a list raises TypeError
    config = PretrainedConfig(sbert_ce_default_activation_function=["torch.nn.modules.activation.ReLU"])
    try:
        get_cross_encoder_activation_function(config)
    except Exception as e:
        pass
    else:
        raise AssertionError("Should raise an exception for non-string activation function type")

def test_activation_function_is_valid_but_not_module_raises():
    # Test: config with activation function pointing to a valid module but not a class raises TypeError
    config = PretrainedConfig(sbert_ce_default_activation_function="torch.nn.modules.activation")
    try:
        get_cross_encoder_activation_function(config)
    except TypeError:
        pass
    except Exception as e:
        pass
    else:
        raise AssertionError("Should raise an exception when qualname is not a class")
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
```

```python
import importlib
from typing import Any

# imports
import pytest  # used for our unit tests
import torch.nn as nn
from sglang.srt.layers.activation import get_cross_encoder_activation_function


class DummyConfig:
    """A dummy config class to simulate PretrainedConfig for testing."""
    def __init__(self, sbert_ce_default_activation_function=None):
        self.sbert_ce_default_activation_function = sbert_ce_default_activation_function

# unit tests

# Basic Test Cases

def test_returns_identity_when_config_has_no_attr():
    """Test that nn.Identity is returned when config has no sbert_ce_default_activation_function attribute."""
    class NoAttrConfig:
        pass
    codeflash_output = get_cross_encoder_activation_function(NoAttrConfig()); act_fn = codeflash_output

def test_returns_identity_when_attr_is_none():
    """Test that nn.Identity is returned when sbert_ce_default_activation_function is None."""
    config = DummyConfig()
    codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output

def test_returns_correct_activation_function_relu():
    """Test that ReLU is returned when sbert_ce_default_activation_function is set to torch.nn.modules.activation.ReLU."""
    config = DummyConfig("torch.nn.modules.activation.ReLU")
    codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output

def test_returns_correct_activation_function_gelu():
    """Test that GELU is returned when sbert_ce_default_activation_function is set to torch.nn.modules.activation.GELU."""
    config = DummyConfig("torch.nn.modules.activation.GELU")
    codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output


def test_invalid_module_prefix_raises():
    """Test that an invalid module prefix raises an AssertionError."""
    config = DummyConfig("torch.nn.ReLU")
    with pytest.raises(AssertionError):
        get_cross_encoder_activation_function(config)

def test_nonexistent_activation_function_raises():
    """Test that a non-existent activation function raises AttributeError."""
    config = DummyConfig("torch.nn.modules.activation.NotAFunction")
    with pytest.raises(AttributeError):
        get_cross_encoder_activation_function(config)

def test_non_string_activation_function_raises():
    """Test that a non-string activation function value raises AttributeError."""
    config = DummyConfig(12345)
    # Should fail at .startswith
    with pytest.raises(AttributeError):
        get_cross_encoder_activation_function(config)

def test_empty_string_activation_function_raises():
    """Test that an empty string activation function value raises AssertionError."""
    config = DummyConfig("")
    with pytest.raises(AssertionError):
        get_cross_encoder_activation_function(config)




def test_activation_function_with_leading_dot_raises():
    """Test that an activation function with leading dot raises AttributeError."""
    config = DummyConfig(".torch.nn.modules.activation.ReLU")
    with pytest.raises(AssertionError):
        get_cross_encoder_activation_function(config)

# Large Scale Test Cases

def test_large_number_of_configs_identity():
    """Test multiple configs with None activation function for scalability."""
    configs = [DummyConfig() for _ in range(500)]
    for config in configs:
        codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output

def test_large_number_of_configs_relu():
    """Test multiple configs with ReLU activation function for scalability."""
    configs = [DummyConfig("torch.nn.modules.activation.ReLU") for _ in range(500)]
    for config in configs:
        codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output


def test_large_scale_memory_efficiency():
    """Test that large number of activations do not exceed memory limits."""
    configs = [DummyConfig("torch.nn.modules.activation.ReLU") for _ in range(1000)]
    for i, config in enumerate(configs):
        codeflash_output = get_cross_encoder_activation_function(config); act_fn = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
```

To edit these changes, `git checkout codeflash/optimize-get_cross_encoder_activation_function-mh2ueub9` and push.

codeflash-ai[bot] requested a review from mashraf-222 on October 23, 2025 at 03:08
codeflash-ai[bot] added the ⚡️ codeflash label (Optimization PR opened by Codeflash AI) on Oct 23, 2025