
Commit faa87e2

fix: handle max_completion_tokens error for newer openai models (#2413)
1 parent dbf91ab commit faa87e2

File tree

1 file changed

src/ragas/llms/base.py

Lines changed: 116 additions & 36 deletions
@@ -744,6 +744,110 @@ def __init__(
         # Check if client is async-capable at initialization
         self.is_async = self._check_client_async()
 
+    def _map_provider_params(self) -> t.Dict[str, t.Any]:
+        """Route to provider-specific parameter mapping.
+
+        Each provider may have different parameter requirements:
+        - Google: Wraps parameters in generation_config and renames max_tokens
+        - OpenAI: Maps max_tokens to max_completion_tokens for reasoning models
+        - Anthropic: No special handling required (pass-through)
+        - LiteLLM: No special handling required (routes internally, pass-through)
+        """
+        provider_lower = self.provider.lower()
+
+        if provider_lower == "google":
+            return self._map_google_params()
+        elif provider_lower == "openai":
+            return self._map_openai_params()
+        else:
+            # Anthropic, LiteLLM - pass through unchanged
+            return self.model_args.copy()
+
+    def _map_openai_params(self) -> t.Dict[str, t.Any]:
+        """Map max_tokens to max_completion_tokens for OpenAI reasoning models.
+
+        Reasoning models (o-series and gpt-5 series) require max_completion_tokens
+        instead of the deprecated max_tokens parameter when using the Chat Completions API.
+
+        Legacy OpenAI models (gpt-4, gpt-4o, etc.) continue to use max_tokens unchanged.
+
+        Pattern-based matching for future-proof coverage:
+        - O-series: o1, o2, o3, o4, o5, ... (all reasoning versions)
+        - GPT-5 series: gpt-5, gpt-5-*, gpt-6, gpt-7, ... (all GPT-5+ models)
+        - Other: codex-mini
+        """
+        mapped_args = self.model_args.copy()
+
+        model_lower = self.model.lower()
+
+        # Pattern-based detection for reasoning models that require max_completion_tokens
+        # Uses prefix matching to cover current and future model variants
+        def is_reasoning_model(model_str: str) -> bool:
+            """Check if model is a reasoning model requiring max_completion_tokens."""
+            # O-series reasoning models (o1, o1-mini, o1-2024-12-17, o2, o3, ..., o9)
+            # Pattern: "o" followed by a single digit 1-9, then "-", "_", or end of string
+            # TODO: Update to support o10+ when OpenAI releases models beyond o9
+            if (
+                len(model_str) >= 2
+                and model_str[0] == "o"
+                and model_str[1] in "123456789"
+            ):
+                # Allow single-digit o-series: o1, o2, ..., o9
+                if len(model_str) == 2 or model_str[2] in ("-", "_"):
+                    return True
+
+            # GPT-5 and newer generation models (gpt-5, gpt-5-*, gpt-6, ..., gpt-19)
+            # Pattern: "gpt-" followed by a one- or two-digit version >= 5, max 19
+            # TODO: Update to support gpt-20+ when OpenAI releases models beyond gpt-19
+            if model_str.startswith("gpt-"):
+                version_str = (
+                    model_str[4:].split("-")[0].split("_")[0]
+                )  # Get version number
+                try:
+                    version = int(version_str)
+                    if 5 <= version <= 19:
+                        return True
+                except ValueError:
+                    pass
+
+            # Other specific reasoning models
+            if model_str == "codex-mini":
+                return True
+
+            return False
+
+        requires_max_completion_tokens = is_reasoning_model(model_lower)
+
+        # If max_tokens is provided and the model requires max_completion_tokens, map it
+        if requires_max_completion_tokens and "max_tokens" in mapped_args:
+            mapped_args["max_completion_tokens"] = mapped_args.pop("max_tokens")
+
+        return mapped_args
+
+    def _map_google_params(self) -> t.Dict[str, t.Any]:
+        """Map parameters for Google Gemini models.
+
+        Google models require parameters to be wrapped in a generation_config dict,
+        and max_tokens is renamed to max_output_tokens.
+        """
+        google_kwargs = {}
+        generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
+        generation_config = {}
+
+        for key, value in self.model_args.items():
+            if key in generation_config_keys:
+                if key == "max_tokens":
+                    generation_config["max_output_tokens"] = value
+                else:
+                    generation_config[key] = value
+            else:
+                google_kwargs[key] = value
+
+        if generation_config:
+            google_kwargs["generation_config"] = generation_config
+
+        return google_kwargs
+
     def _check_client_async(self) -> bool:
         """Determine if the client is async-capable."""
         try:
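
To make the detection pattern concrete, here is a small standalone sketch (not part of the commit) that mirrors the is_reasoning_model logic above and exercises it against a few model names. The gpt-5-turbo name is an illustrative assumption, not a claim about OpenAI's catalogue:

# Standalone mirror of the commit's detection logic, for illustration only.
def is_reasoning_model(model_str: str) -> bool:
    # O-series: "o" + one digit 1-9, followed by "-", "_", or end of string
    if (
        len(model_str) >= 2
        and model_str[0] == "o"
        and model_str[1] in "123456789"
    ):
        if len(model_str) == 2 or model_str[2] in ("-", "_"):
            return True
    # GPT-5 through GPT-19: "gpt-" + an integer version in [5, 19]
    if model_str.startswith("gpt-"):
        version_str = model_str[4:].split("-")[0].split("_")[0]
        try:
            if 5 <= int(version_str) <= 19:
                return True
        except ValueError:
            pass
    # One-off reasoning model
    return model_str == "codex-mini"

# Mapped to max_completion_tokens:
assert is_reasoning_model("o1-mini")
assert is_reasoning_model("o3")
assert is_reasoning_model("gpt-5")
assert is_reasoning_model("gpt-5-turbo")  # illustrative name
assert is_reasoning_model("codex-mini")
# Left on max_tokens:
assert not is_reasoning_model("gpt-4o")   # int("4o") raises ValueError
assert not is_reasoning_model("gpt-4.1")  # int("4.1") raises ValueError
assert not is_reasoning_model("omni-moderation-latest")  # "m" is not a digit
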
@@ -826,34 +930,22 @@ def generate(
                 self.agenerate(prompt, response_model)
             )
         else:
-            if self.provider.lower() == "google":
-                google_kwargs = {}
-                generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
-                generation_config = {}
-
-                for key, value in self.model_args.items():
-                    if key in generation_config_keys:
-                        if key == "max_tokens":
-                            generation_config["max_output_tokens"] = value
-                        else:
-                            generation_config[key] = value
-                    else:
-                        google_kwargs[key] = value
-
-                if generation_config:
-                    google_kwargs["generation_config"] = generation_config
+            # Map parameters based on provider requirements
+            provider_kwargs = self._map_provider_params()
 
+            if self.provider.lower() == "google":
                 result = self.client.create(
                     messages=messages,
                     response_model=response_model,
-                    **google_kwargs,
+                    **provider_kwargs,
                 )
             else:
+                # OpenAI, Anthropic, LiteLLM
                 result = self.client.chat.completions.create(
                     model=self.model,
                     messages=messages,
                     response_model=response_model,
-                    **self.model_args,
+                    **provider_kwargs,
                 )
 
             # Track the usage
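
The motivation for the OpenAI branch: the Chat Completions API rejects max_tokens for reasoning models with an error that asks for max_completion_tokens instead, which is the failure named in the commit title. Reusing the is_reasoning_model sketch above, the effect on the kwargs that reach client.chat.completions.create() looks roughly like this (map_openai_params is a hypothetical free-function stand-in for the _map_openai_params method):

def map_openai_params(model: str, model_args: dict) -> dict:
    # Copy so the caller's dict is untouched, then rename the token-budget
    # key only when the model requires max_completion_tokens.
    mapped = dict(model_args)
    if is_reasoning_model(model.lower()) and "max_tokens" in mapped:
        mapped["max_completion_tokens"] = mapped.pop("max_tokens")
    return mapped

print(map_openai_params("o1-mini", {"temperature": 0.0, "max_tokens": 1024}))
# {'temperature': 0.0, 'max_completion_tokens': 1024}
print(map_openai_params("gpt-4o", {"temperature": 0.0, "max_tokens": 1024}))
# {'temperature': 0.0, 'max_tokens': 1024}
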
@@ -882,34 +974,22 @@ async def agenerate(
                 "Cannot use agenerate() with a synchronous client. Use generate() instead."
             )
 
-        if self.provider.lower() == "google":
-            google_kwargs = {}
-            generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
-            generation_config = {}
-
-            for key, value in self.model_args.items():
-                if key in generation_config_keys:
-                    if key == "max_tokens":
-                        generation_config["max_output_tokens"] = value
-                    else:
-                        generation_config[key] = value
-                else:
-                    google_kwargs[key] = value
-
-            if generation_config:
-                google_kwargs["generation_config"] = generation_config
+        # Map parameters based on provider requirements
+        provider_kwargs = self._map_provider_params()
 
+        if self.provider.lower() == "google":
             result = await self.client.create(
                 messages=messages,
                 response_model=response_model,
-                **google_kwargs,
+                **provider_kwargs,
             )
         else:
+            # OpenAI, Anthropic, LiteLLM
             result = await self.client.chat.completions.create(
                 model=self.model,
                 messages=messages,
                 response_model=response_model,
-                **self.model_args,
+                **provider_kwargs,
            )
 
        # Track the usage
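
The agenerate() hunk mirrors generate(): both methods previously inlined the same Google mapping, which the commit centralises in _map_provider_params(). For reference, the Google branch behaves like this sketch (map_google_params is a hypothetical free-function stand-in for _map_google_params):

def map_google_params(model_args: dict) -> dict:
    # Keys recognised as Gemini generation settings are nested under
    # generation_config; max_tokens is renamed to max_output_tokens.
    generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
    google_kwargs: dict = {}
    generation_config: dict = {}
    for key, value in model_args.items():
        if key in generation_config_keys:
            target = "max_output_tokens" if key == "max_tokens" else key
            generation_config[target] = value
        else:
            google_kwargs[key] = value
    if generation_config:
        google_kwargs["generation_config"] = generation_config
    return google_kwargs

print(map_google_params({"temperature": 0.2, "max_tokens": 512, "seed": 7}))
# {'seed': 7, 'generation_config': {'temperature': 0.2, 'max_output_tokens': 512}}
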
