Add dynamic weighting function to UsageLimits #3189

@otto-sellerstam

Description

The UsageLimits class currently accepts only absolute token counts (input, output, and total) as hard limits. I propose adding an optional attribute, usage_limit_func: Callable[[RunUsage], bool] | None = None (or similarly named), that takes a RunUsage object and returns True when the limit has been exceeded, for example:

def usage_limit_func(usage: RunUsage) -> bool: ...

This function would be called in UsageLimits's check_tokens method. If it is defined, I'd propose that it is checked first, ahead of the other usage limit constraints, like so:

def check_tokens(self, usage: RunUsage) -> None:
    """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits."""
    # Proposed: the user-supplied check runs before the built-in limits.
    if self.usage_limit_func is not None and self.usage_limit_func(usage):
        raise UsageLimitExceeded('Exceeded the usage_limit_func limit')

    input_tokens = usage.input_tokens
    if self.input_tokens_limit is not None and input_tokens > self.input_tokens_limit:
        raise UsageLimitExceeded(f'Exceeded the input_tokens_limit of {self.input_tokens_limit} ({input_tokens=})')

    output_tokens = usage.output_tokens
    if self.output_tokens_limit is not None and output_tokens > self.output_tokens_limit:
        raise UsageLimitExceeded(
            f'Exceeded the output_tokens_limit of {self.output_tokens_limit} ({output_tokens=})'
        )

    total_tokens = usage.total_tokens
    if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
        raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
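
To make the proposed check order concrete, here is a minimal, self-contained sketch that can be run without the library installed. The RunUsage, UsageLimits, and UsageLimitExceeded classes below are simplified stand-ins written for this illustration, not the library's actual implementations, and usage_limit_func is the proposed attribute, which does not exist yet.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable


class UsageLimitExceeded(Exception):
    """Stand-in for the library's exception of the same name."""


@dataclass
class RunUsage:
    """Simplified stand-in holding only the fields used in this issue."""

    input_tokens: int = 0
    output_tokens: int = 0
    cache_read_tokens: int = 0

    @property
    def total_tokens(self) -> int:
        return self.input_tokens + self.output_tokens


@dataclass
class UsageLimits:
    """Simplified stand-in with the proposed usage_limit_func attribute."""

    input_tokens_limit: int | None = None
    output_tokens_limit: int | None = None
    total_tokens_limit: int | None = None
    usage_limit_func: Callable[[RunUsage], bool] | None = None  # proposed

    def check_tokens(self, usage: RunUsage) -> None:
        # The custom check runs first; the built-in limits still apply after it.
        if self.usage_limit_func is not None and self.usage_limit_func(usage):
            raise UsageLimitExceeded('Exceeded the usage_limit_func limit')
        if self.input_tokens_limit is not None and usage.input_tokens > self.input_tokens_limit:
            raise UsageLimitExceeded(f'Exceeded the input_tokens_limit of {self.input_tokens_limit}')
        if self.output_tokens_limit is not None and usage.output_tokens > self.output_tokens_limit:
            raise UsageLimitExceeded(f'Exceeded the output_tokens_limit of {self.output_tokens_limit}')
        if self.total_tokens_limit is not None and usage.total_tokens > self.total_tokens_limit:
            raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit}')


limits = UsageLimits(
    total_tokens_limit=10_000,
    usage_limit_func=lambda u: u.input_tokens + 3 * u.output_tokens > 1_000,
)

# Weighted usage is 400 + 3 * 100 = 700, so no exception is raised.
limits.check_tokens(RunUsage(input_tokens=400, output_tokens=100))

# Weighted usage is 400 + 3 * 300 = 1300, so the custom check raises
# even though total_tokens (700) is well under total_tokens_limit.
try:
    limits.check_tokens(RunUsage(input_tokens=400, output_tokens=300))
except UsageLimitExceeded as exc:
    print(exc)
```

Note that in this sketch the custom function does not replace the built-in limits: if it returns False, the absolute limits are still enforced.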

Use cases

This would let users weight token types dynamically when enforcing limits, for example to account for pricing differences between input and output tokens, or to exclude cache read tokens from the input token count.

Examples

Removing cache_read_tokens from input_tokens

max_input_tokens = 1_000_000  # Excluding cache_read_tokens
def usage_limit_func(usage: RunUsage) -> bool:
    """Return True if usage limits are exceeded, False otherwise."""
    return usage.input_tokens - usage.cache_read_tokens > max_input_tokens

usage_limit = UsageLimits(usage_limit_func=usage_limit_func)

Combining and scaling input_tokens and output_tokens

max_weighted_tokens = 1_000_000
def usage_limit_func(usage: RunUsage) -> bool:
    """Return True if usage limits are exceeded, False otherwise."""
    return usage.input_tokens + 3 * usage.output_tokens > max_weighted_tokens

usage_limit = UsageLimits(usage_limit_func=usage_limit_func)

Dollar-based budget limit

max_cost_usd = 0.10
def usage_limit_func(usage: RunUsage) -> bool:
    """Return True if the estimated cost exceeds the budget, False otherwise."""
    # Example rates: $1 per million input tokens, $3 per million output tokens.
    cost = usage.input_tokens * 0.000001 + usage.output_tokens * 0.000003
    return cost > max_cost_usd

usage_limit = UsageLimits(usage_limit_func=usage_limit_func)
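
The cache and cost examples above could also be combined into one reusable factory. The sketch below is only an illustration: TokenPrices, HasTokenCounts, and make_cost_limit_func are hypothetical names that are not part of any library, the prices are placeholder values, and it assumes (as in the first example) that input_tokens includes cache_read_tokens.

```python
from dataclasses import dataclass
from typing import Callable, Protocol


class HasTokenCounts(Protocol):
    """Hypothetical structural type: anything exposing these three counters."""

    input_tokens: int
    output_tokens: int
    cache_read_tokens: int


@dataclass(frozen=True)
class TokenPrices:
    """Hypothetical container for USD prices per one million tokens."""

    input_per_million: float
    output_per_million: float
    cache_read_per_million: float


def make_cost_limit_func(
    prices: TokenPrices, max_cost_usd: float
) -> Callable[[HasTokenCounts], bool]:
    """Build a function that returns True once the estimated cost exceeds the budget."""

    def usage_limit_func(usage: HasTokenCounts) -> bool:
        # Assumption: cache reads are counted inside input_tokens, so split them out
        # and bill them at the (usually lower) cache read rate.
        uncached_input = usage.input_tokens - usage.cache_read_tokens
        cost = (
            uncached_input * prices.input_per_million
            + usage.cache_read_tokens * prices.cache_read_per_million
            + usage.output_tokens * prices.output_per_million
        ) / 1_000_000
        return cost > max_cost_usd

    return usage_limit_func


# Placeholder prices for illustration only; substitute the real rates of the model in use.
example_prices = TokenPrices(input_per_million=1.0, output_per_million=3.0, cache_read_per_million=0.1)
cost_limit_func = make_cost_limit_func(example_prices, max_cost_usd=0.10)
```

The resulting cost_limit_func has the same Callable[[RunUsage], bool] shape as the examples above, so under this proposal it could be passed as UsageLimits(usage_limit_func=cost_limit_func).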

References

No response
