Merged
152 changes: 116 additions & 36 deletions src/ragas/llms/base.py
@@ -594,6 +594,110 @@ def __init__(
# Check if client is async-capable at initialization
self.is_async = self._check_client_async()

    def _map_provider_params(self) -> t.Dict[str, t.Any]:
        """Route to provider-specific parameter mapping.

        Each provider has different parameter requirements:
        - Google: wraps generation parameters in generation_config and renames
          max_tokens to max_output_tokens
        - OpenAI: maps max_tokens to max_completion_tokens for reasoning models
          (o-series, gpt-5 and later)
        - Anthropic: no special handling required (pass-through)
        - LiteLLM: no special handling required (routes internally, pass-through)
        """
provider_lower = self.provider.lower()

if provider_lower == "google":
return self._map_google_params()
elif provider_lower == "openai":
return self._map_openai_params()
else:
# Anthropic, LiteLLM - pass through unchanged
return self.model_args.copy()

    def _map_openai_params(self) -> t.Dict[str, t.Any]:
        """Map max_tokens to max_completion_tokens for OpenAI reasoning models.

        Reasoning models (o-series and gpt-5 series) require max_completion_tokens
        instead of the deprecated max_tokens parameter when using the Chat
        Completions API.

        Legacy OpenAI models (gpt-4, gpt-4o, etc.) continue to use max_tokens
        unchanged.

        Pattern-based matching keeps coverage future-proof:
        - O-series: o1, o2, ..., o9 (all reasoning versions)
        - GPT-5 series: gpt-5, gpt-5-*, gpt-6, ..., gpt-19 (see TODOs below)
        - Other: codex-mini
        """
mapped_args = self.model_args.copy()

model_lower = self.model.lower()

# Pattern-based detection for reasoning models that require max_completion_tokens
# Uses prefix matching to cover current and future model variants
def is_reasoning_model(model_str: str) -> bool:
"""Check if model is a reasoning model requiring max_completion_tokens."""
# O-series reasoning models (o1, o1-mini, o1-2024-12-17, o2, o3, o4, o5, o6, o7, o8, o9)
# Pattern: "o" followed by single digit 1-9, then optional "-" or end of string
# TODO: Update to support o10+ when OpenAI releases models beyond o9
if (
len(model_str) >= 2
and model_str[0] == "o"
and model_str[1] in "123456789"
):
# Allow single digit o-series: o1, o2, ..., o9
if len(model_str) == 2 or model_str[2] in ("-", "_"):
return True

# GPT-5 and newer generation models (gpt-5, gpt-5-*, gpt-6, gpt-7, ..., gpt-19)
# Pattern: "gpt-" followed by single or double digit >= 5, max 19
# TODO: Update to support gpt-20+ when OpenAI releases models beyond gpt-19
if model_str.startswith("gpt-"):
version_str = (
model_str[4:].split("-")[0].split("_")[0]
) # Get version number
try:
version = int(version_str)
if 5 <= version <= 19:
return True
except ValueError:
pass

# Other specific reasoning models
if model_str == "codex-mini":
return True

return False

requires_max_completion_tokens = is_reasoning_model(model_lower)

# If max_tokens is provided and model requires max_completion_tokens, map it
if requires_max_completion_tokens and "max_tokens" in mapped_args:
mapped_args["max_completion_tokens"] = mapped_args.pop("max_tokens")

return mapped_args

def _map_google_params(self) -> t.Dict[str, t.Any]:
"""Map parameters for Google Gemini models.

Google models require parameters to be wrapped in a generation_config dict,
and max_tokens is renamed to max_output_tokens.
"""
google_kwargs = {}
generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
generation_config = {}

for key, value in self.model_args.items():
if key in generation_config_keys:
if key == "max_tokens":
generation_config["max_output_tokens"] = value
else:
generation_config[key] = value
else:
google_kwargs[key] = value

if generation_config:
google_kwargs["generation_config"] = generation_config

return google_kwargs

def _check_client_async(self) -> bool:
"""Determine if the client is async-capable."""
try:
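To make the routing concrete, here is a minimal sketch of what `_map_provider_params()` returns for a few provider/model combinations. The `_Stub` harness is purely illustrative, and `InstructorLLM` is an assumed name for the class this diff patches (adjust the import to the actual class in `src/ragas/llms/base.py`); only the `provider`, `model`, and `model_args` attributes matter to the mappers.

```python
# Illustrative sketch, not part of the PR: borrow the mapping methods and
# exercise them on a stub carrying only the attributes they read.
from ragas.llms.base import InstructorLLM  # assumed class name; adjust as needed


class _Stub:
    def __init__(self, provider: str, model: str, **model_args):
        self.provider = provider
        self.model = model
        self.model_args = model_args


# Bind the mapping methods from the patched class onto the stub.
for _name in ("_map_provider_params", "_map_openai_params", "_map_google_params"):
    setattr(_Stub, _name, getattr(InstructorLLM, _name))

# OpenAI reasoning model: max_tokens is renamed.
assert _Stub("openai", "o3-mini", max_tokens=512)._map_provider_params() == {
    "max_completion_tokens": 512
}

# Legacy OpenAI model: passed through untouched.
assert _Stub("openai", "gpt-4o", max_tokens=512)._map_provider_params() == {
    "max_tokens": 512
}

# Google: generation params wrapped, max_tokens renamed, the rest kept top-level.
assert _Stub(
    "google", "gemini-2.0-flash", temperature=0.2, max_tokens=256, timeout=30
)._map_provider_params() == {
    "timeout": 30,
    "generation_config": {"temperature": 0.2, "max_output_tokens": 256},
}

# Anthropic: pass-through copy.
assert _Stub("anthropic", "claude-sonnet-4", max_tokens=512)._map_provider_params() == {
    "max_tokens": 512
}
```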
@@ -676,34 +780,22 @@ def generate(
                 self.agenerate(prompt, response_model)
             )
         else:
-            if self.provider.lower() == "google":
-                google_kwargs = {}
-                generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
-                generation_config = {}
-
-                for key, value in self.model_args.items():
-                    if key in generation_config_keys:
-                        if key == "max_tokens":
-                            generation_config["max_output_tokens"] = value
-                        else:
-                            generation_config[key] = value
-                    else:
-                        google_kwargs[key] = value
-
-                if generation_config:
-                    google_kwargs["generation_config"] = generation_config
+            # Map parameters based on provider requirements
+            provider_kwargs = self._map_provider_params()
+
+            if self.provider.lower() == "google":
                 result = self.client.create(
                     messages=messages,
                     response_model=response_model,
-                    **google_kwargs,
+                    **provider_kwargs,
                 )
             else:
+                # OpenAI, Anthropic, LiteLLM
                 result = self.client.chat.completions.create(
                     model=self.model,
                     messages=messages,
                     response_model=response_model,
-                    **self.model_args,
+                    **provider_kwargs,
                 )

         # Track the usage
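Because the reasoning-model detection is pure string logic, it can be checked in isolation. The following standalone version of `is_reasoning_model` (duplicated from the hunk above purely for illustration) exercises the edge cases the comments call out:

```python
# Standalone version of the detection logic from _map_openai_params, duplicated
# here so the edge cases can be exercised without constructing a client.
def is_reasoning_model(model_str: str) -> bool:
    # o1..o9, alone or followed by "-"/"_" (o1, o1-mini, o3-2025-01-31, ...)
    if len(model_str) >= 2 and model_str[0] == "o" and model_str[1] in "123456789":
        if len(model_str) == 2 or model_str[2] in ("-", "_"):
            return True
    # gpt-5 .. gpt-19, with optional suffixes (gpt-5, gpt-5-turbo, ...)
    if model_str.startswith("gpt-"):
        version_str = model_str[4:].split("-")[0].split("_")[0]
        try:
            if 5 <= int(version_str) <= 19:
                return True
        except ValueError:
            pass
    return model_str == "codex-mini"


assert is_reasoning_model("o1-mini")
assert is_reasoning_model("o3")
assert is_reasoning_model("gpt-5-turbo")
assert is_reasoning_model("codex-mini")
assert not is_reasoning_model("gpt-4o")           # "4o" is not an integer version
assert not is_reasoning_model("omni-moderation")  # "m" after "o" is not a digit
assert not is_reasoning_model("o10")              # beyond o9; see the TODO above
```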
@@ -732,34 +824,22 @@ async def agenerate(
"Cannot use agenerate() with a synchronous client. Use generate() instead."
)

if self.provider.lower() == "google":
google_kwargs = {}
generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
generation_config = {}

for key, value in self.model_args.items():
if key in generation_config_keys:
if key == "max_tokens":
generation_config["max_output_tokens"] = value
else:
generation_config[key] = value
else:
google_kwargs[key] = value

if generation_config:
google_kwargs["generation_config"] = generation_config
# Map parameters based on provider requirements
provider_kwargs = self._map_provider_params()

if self.provider.lower() == "google":
result = await self.client.create(
messages=messages,
response_model=response_model,
**google_kwargs,
**provider_kwargs,
)
else:
# OpenAI, Anthropic, LiteLLM
result = await self.client.chat.completions.create(
model=self.model,
messages=messages,
response_model=response_model,
**self.model_args,
**provider_kwargs,
)

# Track the usage
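For completeness, a sketch of the async path end to end. The `build_llm` helper below is hypothetical and stands in for however the instance is constructed in ragas; the point is that `agenerate()` now routes its kwargs through `_map_provider_params()` before calling the client, so callers keep passing `max_tokens` regardless of model.

```python
# Hypothetical end-to-end sketch of the async path (build_llm is a stand-in,
# not a real ragas API). Callers keep passing max_tokens; the provider mapping
# decides whether the client sees max_tokens or max_completion_tokens.
import asyncio

from pydantic import BaseModel


class Verdict(BaseModel):
    reasoning: str
    verdict: bool


async def main() -> None:
    llm = build_llm(provider="openai", model="o3-mini", max_tokens=1024)  # hypothetical
    result = await llm.agenerate("Is the sky blue? Justify briefly.", Verdict)
    print(result.verdict, result.reasoning)


asyncio.run(main())
```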