⚡️ Speed up function _map_usage by 42% #29

Open · wants to merge 1 commit into base: try-refinement
Conversation

codeflash-ai[bot]

@codeflash-ai codeflash-ai bot commented Jul 22, 2025

📄 42% (0.42x) speedup for _map_usage in pydantic_ai_slim/pydantic_ai/models/anthropic.py

⏱️ Runtime: 41.7 microseconds → 29.2 microseconds (best of 205 runs)

📝 Explanation and details

Here is an optimized version of your code, focusing on fast type checks, avoiding unnecessary dictionary comprehensions, and minimizing lookups and function calls. Key changes:

  • Use `type()` comparisons for faster type matching, since the input comes from a closed set of classes (much faster than `isinstance()` for known exact types based on the hit profile; works only if no complex inheritance is involved).
  • Inline the extraction of attributes directly.
  • Accumulate request token counts directly, retrieving only the keys that are used (avoid building an intermediate `details` dict in advance).
  • Avoid calling `.model_dump()` unless there is a usage object (`model_dump()` can be expensive).
  • Only collect details when really necessary, and avoid repeated `.get()` calls.

All docstrings and comments are preserved; they were only changed where the corresponding code was altered.

Why is this faster?

  • `type()` checks outperform `isinstance()` in predictable code paths with a fixed set of input types (as profiled, these come from a controlled API).
  • The dictionary creation and token accumulation minimize key lookups, memory allocations, and unnecessary fallback defaults.
  • No `.get()`, `.model_dump()`, or comprehension calls are made unless there is valid usage data.

If you must support subclassing, use `isinstance()`; otherwise, the above will be the highest-performing option for this workload.
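The `type()` vs `isinstance()` claim is easy to check with a quick microbenchmark; absolute numbers depend heavily on the CPython version (recent releases specialize `isinstance`, so the gap may be small or even reversed), which is why this is a way to verify the claim on your own interpreter rather than a proof of it.

```python
import timeit

class A:
    pass

obj = A()

# Time an exact-type check against an isinstance() check on the same object.
t_type = timeit.timeit(lambda: type(obj) is A, number=1_000_000)
t_isinstance = timeit.timeit(lambda: isinstance(obj, A), number=1_000_000)

print(f"type(obj) is A:     {t_type:.3f}s per 1M calls")
print(f"isinstance(obj, A): {t_isinstance:.3f}s per 1M calls")
```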

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 21 Passed |
| 🌀 Generated Regression Tests | 42 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |

⚙️ Existing Unit Tests and Runtime

| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
| --- | --- | --- | --- |
| models/test_anthropic.py::test_usage | 18.0μs | 14.2μs | ✅ 26.0% |
🌀 Generated Regression Tests and Runtime
from types import SimpleNamespace

# imports
import pytest  # used for our unit tests
from pydantic_ai.models.anthropic import _map_usage


# Mock the dependencies for isolated testing
class Usage:
    def __init__(self, request_tokens=None, response_tokens=None, total_tokens=None, details=None):
        self.request_tokens = request_tokens
        self.response_tokens = response_tokens
        self.total_tokens = total_tokens
        self.details = details

    def __eq__(self, other):
        if not isinstance(other, Usage):
            return False
        return (
            self.request_tokens == other.request_tokens and
            self.response_tokens == other.response_tokens and
            self.total_tokens == other.total_tokens and
            self.details == other.details
        )

    def __repr__(self):
        return (
            f"Usage(request_tokens={self.request_tokens}, "
            f"response_tokens={self.response_tokens}, "
            f"total_tokens={self.total_tokens}, "
            f"details={self.details})"
        )

# Mock anthropic beta types
class BetaMessage:
    def __init__(self, usage):
        self.usage = usage

class BetaRawMessageStartEvent:
    def __init__(self, message):
        self.message = message

class BetaRawMessageDeltaEvent:
    def __init__(self, usage):
        self.usage = usage

class BetaRawMessageStreamEvent:
    pass  # Used for types only

# Mock usage object with model_dump
class MockUsageObj:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
    def model_dump(self):
        return dict(self.__dict__)

# Stand-in "usage" module for the expected Usage return type
# (_map_usage itself was already imported above)
usage = SimpleNamespace(Usage=Usage)

# -------------------------
# Unit tests start here
# -------------------------

# --------- BASIC TEST CASES ---------

def test_basic_beta_message_only_input_and_output_tokens():
    # Test with only input_tokens and output_tokens present
    usage_obj = MockUsageObj(input_tokens=10, output_tokens=20)
    msg = BetaMessage(usage=usage_obj)
    codeflash_output = _map_usage(msg); result = codeflash_output # 666ns -> 416ns (60.1% faster)

def test_basic_beta_message_with_cache_creation_tokens():
    # Test with input_tokens and cache_creation_input_tokens
    usage_obj = MockUsageObj(input_tokens=5, cache_creation_input_tokens=7, output_tokens=12)
    msg = BetaMessage(usage=usage_obj)
    codeflash_output = _map_usage(msg); result = codeflash_output # 583ns -> 375ns (55.5% faster)

def test_basic_beta_message_with_cache_read_tokens():
    # Test with input_tokens, cache_read_input_tokens, and output_tokens
    usage_obj = MockUsageObj(input_tokens=8, cache_read_input_tokens=4, output_tokens=9)
    msg = BetaMessage(usage=usage_obj)
    codeflash_output = _map_usage(msg); result = codeflash_output # 625ns -> 416ns (50.2% faster)

def test_basic_beta_raw_message_start_event():
    # Test BetaRawMessageStartEvent with nested usage
    usage_obj = MockUsageObj(input_tokens=3, output_tokens=6)
    message = SimpleNamespace(usage=usage_obj)
    event = BetaRawMessageStartEvent(message=message)
    codeflash_output = _map_usage(event); result = codeflash_output # 625ns -> 375ns (66.7% faster)

def test_basic_beta_raw_message_delta_event():
    # Test BetaRawMessageDeltaEvent with only output_tokens
    usage_obj = MockUsageObj(output_tokens=15)
    event = BetaRawMessageDeltaEvent(usage=usage_obj)
    codeflash_output = _map_usage(event); result = codeflash_output # 625ns -> 375ns (66.7% faster)

# --------- EDGE TEST CASES ---------

def test_edge_no_usage_info():
    # Test with a type that provides no usage info
    class DummyEvent(BetaRawMessageStreamEvent):
        pass
    dummy = DummyEvent()
    codeflash_output = _map_usage(dummy); result = codeflash_output # 708ns -> 416ns (70.2% faster)

def test_edge_zero_tokens_everywhere():
    # All tokens are zero
    usage_obj = MockUsageObj(input_tokens=0, cache_creation_input_tokens=0, cache_read_input_tokens=0, output_tokens=0)
    msg = BetaMessage(usage=usage_obj)
    codeflash_output = _map_usage(msg); result = codeflash_output # 625ns -> 375ns (66.7% faster)

def test_edge_negative_tokens():
    # Negative tokens (should be possible if API returns so, e.g., for testing)
    usage_obj = MockUsageObj(input_tokens=-5, output_tokens=-10)
    msg = BetaMessage(usage=usage_obj)
    codeflash_output = _map_usage(msg); result = codeflash_output # 625ns -> 375ns (66.7% faster)

def test_edge_non_integer_values_ignored():
    # Non-integer values in usage should be ignored in details
    usage_obj = MockUsageObj(input_tokens=7, output_tokens=13, foo="bar", bar=1.5, baz=None)
    msg = BetaMessage(usage=usage_obj)
    codeflash_output = _map_usage(msg); result = codeflash_output # 583ns -> 375ns (55.5% faster)


def test_edge_extra_integer_fields_in_details():
    # Extra integer fields in usage should be included in details
    usage_obj = MockUsageObj(input_tokens=2, output_tokens=3, foo=99, bar=0)
    msg = BetaMessage(usage=usage_obj)
    codeflash_output = _map_usage(msg); result = codeflash_output # 709ns -> 458ns (54.8% faster)

def test_edge_none_fields_are_ignored():
    # Fields with None values are ignored in details
    usage_obj = MockUsageObj(input_tokens=None, output_tokens=6, foo=None)
    msg = BetaMessage(usage=usage_obj)
    codeflash_output = _map_usage(msg); result = codeflash_output # 625ns -> 416ns (50.2% faster)

# --------- LARGE SCALE TEST CASES ---------

def test_large_many_integer_fields():
    # Test with many integer fields (simulate 1000 fields)
    fields = {f"field_{i}": i for i in range(1000)}
    fields["input_tokens"] = 100
    fields["output_tokens"] = 200
    usage_obj = MockUsageObj(**fields)
    msg = BetaMessage(usage=usage_obj)
    codeflash_output = _map_usage(msg); result = codeflash_output # 625ns -> 375ns (66.7% faster)
    # All integer fields included in details
    for i in range(1000):
        pass

def test_large_input_and_cache_tokens():
    # Test with large values for input_tokens, cache_creation_input_tokens, cache_read_input_tokens
    usage_obj = MockUsageObj(
        input_tokens=300,
        cache_creation_input_tokens=400,
        cache_read_input_tokens=500,
        output_tokens=600
    )
    msg = BetaMessage(usage=usage_obj)
    codeflash_output = _map_usage(msg); result = codeflash_output # 625ns -> 375ns (66.7% faster)

def test_large_beta_raw_message_delta_event():
    # Simulate a large delta event with only output_tokens present
    usage_obj = MockUsageObj(output_tokens=999)
    event = BetaRawMessageDeltaEvent(usage=usage_obj)
    codeflash_output = _map_usage(event); result = codeflash_output # 583ns -> 416ns (40.1% faster)

def test_large_beta_raw_message_start_event_with_many_fields():
    # BetaRawMessageStartEvent with many integer fields
    fields = {f"f{i}": i for i in range(500)}
    fields["input_tokens"] = 50
    fields["output_tokens"] = 100
    usage_obj = MockUsageObj(**fields)
    message = SimpleNamespace(usage=usage_obj)
    event = BetaRawMessageStartEvent(message=message)
    codeflash_output = _map_usage(event); result = codeflash_output # 583ns -> 375ns (55.5% faster)
    for i in range(500):
        pass

def test_large_all_zero_except_output_tokens():
    # All input tokens zero, only output_tokens large
    usage_obj = MockUsageObj(input_tokens=0, cache_creation_input_tokens=0, cache_read_input_tokens=0, output_tokens=1000)
    msg = BetaMessage(usage=usage_obj)
    codeflash_output = _map_usage(msg); result = codeflash_output # 541ns -> 375ns (44.3% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

from types import SimpleNamespace

# imports
import pytest  # used for our unit tests
from pydantic_ai.models.anthropic import _map_usage


# Mocks for anthropic and pydantic_ai types, since we don't have the real modules
class Usage:
    def __init__(self, request_tokens=None, response_tokens=None, total_tokens=None, details=None):
        self.request_tokens = request_tokens
        self.response_tokens = response_tokens
        self.total_tokens = total_tokens
        self.details = details

    def __eq__(self, other):
        if not isinstance(other, Usage):
            return False
        return (
            self.request_tokens == other.request_tokens and
            self.response_tokens == other.response_tokens and
            self.total_tokens == other.total_tokens and
            self.details == other.details
        )

    def __repr__(self):
        return (
            f"Usage(request_tokens={self.request_tokens}, "
            f"response_tokens={self.response_tokens}, "
            f"total_tokens={self.total_tokens}, "
            f"details={self.details})"
        )

class MockUsage:
    """Mock for anthropic usage object with model_dump method."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
    def model_dump(self):
        return dict(self.__dict__)

class BetaMessage:
    def __init__(self, usage):
        self.usage = usage

class BetaRawMessageStartEvent:
    def __init__(self, message):
        self.message = message

class BetaRawMessageDeltaEvent:
    def __init__(self, usage):
        self.usage = usage

class BetaRawMessageStreamEvent:
    # Base class for all stream events, used for isinstance checks
    pass

# Insert our mock Usage into a stand-in "usage" module
# (_map_usage itself was already imported above)
usage = SimpleNamespace(Usage=Usage)

# -----------------------
# Unit tests start here
# -----------------------

# ----------- BASIC TEST CASES -----------

def test_basic_beta_message_minimal():
    # Test a BetaMessage with only output_tokens and input_tokens
    usage_obj = MockUsage(output_tokens=10, input_tokens=5)
    msg = BetaMessage(usage=usage_obj)
    expected = Usage(
        request_tokens=5,
        response_tokens=10,
        total_tokens=15,
        details={'output_tokens': 10, 'input_tokens': 5}
    )
    codeflash_output = _map_usage(msg) # 750ns -> 542ns (38.4% faster)

def test_basic_beta_message_with_extra_fields():
    # Test BetaMessage with extra integer fields
    usage_obj = MockUsage(output_tokens=10, input_tokens=7, foo=3, bar=0)
    msg = BetaMessage(usage=usage_obj)
    expected = Usage(
        request_tokens=7,
        response_tokens=10,
        total_tokens=17,
        details={'output_tokens': 10, 'input_tokens': 7, 'foo': 3, 'bar': 0}
    )
    codeflash_output = _map_usage(msg) # 666ns -> 458ns (45.4% faster)

def test_basic_betarawmessagestartevent():
    # Test BetaRawMessageStartEvent with usage inside message
    usage_obj = MockUsage(output_tokens=8, input_tokens=4)
    message = SimpleNamespace(usage=usage_obj)
    event = BetaRawMessageStartEvent(message=message)
    expected = Usage(
        request_tokens=4,
        response_tokens=8,
        total_tokens=12,
        details={'output_tokens': 8, 'input_tokens': 4}
    )
    codeflash_output = _map_usage(event) # 708ns -> 417ns (69.8% faster)

def test_basic_betarawmessagedeltaevent():
    # Test BetaRawMessageDeltaEvent with only output_tokens
    usage_obj = MockUsage(output_tokens=12)
    event = BetaRawMessageDeltaEvent(usage=usage_obj)
    expected = Usage(
        request_tokens=None,
        response_tokens=12,
        total_tokens=12,
        details={'output_tokens': 12}
    )
    codeflash_output = _map_usage(event) # 625ns -> 416ns (50.2% faster)

# ----------- EDGE TEST CASES -----------

def test_no_usage_info_stream_event():
    # Test an unknown BetaRawMessageStreamEvent (should return empty Usage)
    class DummyEvent(BetaRawMessageStreamEvent):
        pass
    event = DummyEvent()
    expected = Usage()
    codeflash_output = _map_usage(event) # 708ns -> 417ns (69.8% faster)

def test_zero_tokens_everywhere():
    # All token counts are zero
    usage_obj = MockUsage(output_tokens=0, input_tokens=0)
    msg = BetaMessage(usage=usage_obj)
    expected = Usage(
        request_tokens=None,
        response_tokens=0,
        total_tokens=0,
        details={'output_tokens': 0, 'input_tokens': 0}
    )
    codeflash_output = _map_usage(msg) # 625ns -> 416ns (50.2% faster)

def test_negative_tokens():
    # Negative tokens (should be preserved in details and counts)
    usage_obj = MockUsage(output_tokens=-5, input_tokens=-7)
    msg = BetaMessage(usage=usage_obj)
    expected = Usage(
        request_tokens=-7,
        response_tokens=-5,
        total_tokens=-12,
        details={'output_tokens': -5, 'input_tokens': -7}
    )
    codeflash_output = _map_usage(msg) # 625ns -> 416ns (50.2% faster)

def test_only_cache_creation_input_tokens():
    # Only cache_creation_input_tokens is present
    usage_obj = MockUsage(output_tokens=2, cache_creation_input_tokens=9)
    msg = BetaMessage(usage=usage_obj)
    expected = Usage(
        request_tokens=9,
        response_tokens=2,
        total_tokens=11,
        details={'output_tokens': 2, 'cache_creation_input_tokens': 9}
    )
    codeflash_output = _map_usage(msg) # 625ns -> 375ns (66.7% faster)

def test_only_cache_read_input_tokens():
    # Only cache_read_input_tokens is present
    usage_obj = MockUsage(output_tokens=3, cache_read_input_tokens=11)
    msg = BetaMessage(usage=usage_obj)
    expected = Usage(
        request_tokens=11,
        response_tokens=3,
        total_tokens=14,
        details={'output_tokens': 3, 'cache_read_input_tokens': 11}
    )
    codeflash_output = _map_usage(msg) # 583ns -> 375ns (55.5% faster)

def test_multiple_input_token_fields():
    # All three input token fields present
    usage_obj = MockUsage(output_tokens=5, input_tokens=2, cache_creation_input_tokens=3, cache_read_input_tokens=4)
    msg = BetaMessage(usage=usage_obj)
    expected = Usage(
        request_tokens=2 + 3 + 4,
        response_tokens=5,
        total_tokens=14,
        details={'output_tokens': 5, 'input_tokens': 2, 'cache_creation_input_tokens': 3, 'cache_read_input_tokens': 4}
    )
    codeflash_output = _map_usage(msg) # 584ns -> 375ns (55.7% faster)

def test_non_integer_fields_ignored():
    # Non-integer fields should not appear in details
    usage_obj = MockUsage(output_tokens=7, input_tokens=3, foo='bar', bar=None, baz=2.5)
    msg = BetaMessage(usage=usage_obj)
    expected = Usage(
        request_tokens=3,
        response_tokens=7,
        total_tokens=10,
        details={'output_tokens': 7, 'input_tokens': 3}
    )
    codeflash_output = _map_usage(msg) # 583ns -> 417ns (39.8% faster)

def test_details_none_if_no_int_fields():
    # No integer fields: details should be None
    usage_obj = MockUsage(foo='bar', baz=None)
    msg = BetaMessage(usage=usage_obj)
    expected = Usage(
        request_tokens=None,
        response_tokens=None,
        total_tokens=None,
        details=None
    )
    codeflash_output = _map_usage(msg) # 583ns -> 416ns (40.1% faster)


def test_large_number_of_fields():
    # Many integer fields in usage
    fields = {f'field_{i}': i for i in range(100)}
    fields['output_tokens'] = 100
    fields['input_tokens'] = 200
    usage_obj = MockUsage(**fields)
    msg = BetaMessage(usage=usage_obj)
    expected_details = fields.copy()
    expected = Usage(
        request_tokens=200,
        response_tokens=100,
        total_tokens=300,
        details=expected_details
    )
    codeflash_output = _map_usage(msg) # 666ns -> 459ns (45.1% faster)

def test_large_token_counts():
    # Very large token counts
    usage_obj = MockUsage(output_tokens=999999, input_tokens=888888)
    msg = BetaMessage(usage=usage_obj)
    expected = Usage(
        request_tokens=888888,
        response_tokens=999999,
        total_tokens=888888 + 999999,
        details={'output_tokens': 999999, 'input_tokens': 888888}
    )
    codeflash_output = _map_usage(msg) # 584ns -> 417ns (40.0% faster)

def test_large_scale_betarawmessagedeltaevent():
    # Large scale BetaRawMessageDeltaEvent with many details
    details = {f'foo{i}': i for i in range(500)}
    details['output_tokens'] = 1234
    usage_obj = MockUsage(**details)
    event = BetaRawMessageDeltaEvent(usage=usage_obj)
    expected_details = details.copy()
    expected = Usage(
        request_tokens=None,
        response_tokens=1234,
        total_tokens=1234,
        details=expected_details
    )
    codeflash_output = _map_usage(event) # 625ns -> 375ns (66.7% faster)

def test_large_scale_multiple_messages():
    # Test a batch of messages to check for memory/efficiency
    for i in range(10):  # 10 is enough for a unit test, but can be increased
        usage_obj = MockUsage(output_tokens=i, input_tokens=i * 2)
        msg = BetaMessage(usage=usage_obj)
        expected = Usage(
            request_tokens=i * 2,
            response_tokens=i,
            total_tokens=i * 3,
            details={'output_tokens': i, 'input_tokens': i * 2}
        )
        codeflash_output = _map_usage(msg) # 3.58μs -> 2.00μs (79.3% faster)

def test_large_scale_many_token_fields():
    # All three input token fields with large values
    usage_obj = MockUsage(
        output_tokens=500,
        input_tokens=100,
        cache_creation_input_tokens=200,
        cache_read_input_tokens=300
    )
    msg = BetaMessage(usage=usage_obj)
    expected = Usage(
        request_tokens=100 + 200 + 300,
        response_tokens=500,
        total_tokens=1100,
        details={
            'output_tokens': 500,
            'input_tokens': 100,
            'cache_creation_input_tokens': 200,
            'cache_read_input_tokens': 300
        }
    )
    codeflash_output = _map_usage(msg) # 625ns -> 416ns (50.2% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-_map_usage-mdexafbs` and push.

Codeflash

@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Jul 22, 2025
@codeflash-ai codeflash-ai bot requested a review from aseembits93 July 22, 2025 19:23