Skip to content

Commit d6c0f30

Browse files
committed
fix: handle max_completion_tokens error for newer openai models
1 parent 8501a49 commit d6c0f30

File tree

1 file changed

+36
-2
lines changed

1 file changed

+36
-2
lines changed

src/ragas/llms/base.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,36 @@ def __init__(
594594
# Check if client is async-capable at initialization
595595
self.is_async = self._check_client_async()
596596

597+
def _map_openai_params(self, model_args: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
598+
"""Map max_tokens to max_completion_tokens for o-series and newer OpenAI models.
599+
600+
O-series models (o1, o3, etc.) and some newer models like gpt-5-mini
601+
require max_completion_tokens instead of the deprecated max_tokens parameter.
602+
"""
603+
mapped_args = model_args.copy()
604+
605+
# List of models that require max_completion_tokens
606+
models_requiring_max_completion_tokens = [
607+
"o1",
608+
"o3",
609+
"o1-mini",
610+
"o3-mini",
611+
"gpt-5",
612+
"gpt-5-mini",
613+
]
614+
615+
# Check if the model matches any of the patterns
616+
model_lower = self.model.lower()
617+
requires_max_completion_tokens = any(
618+
pattern in model_lower for pattern in models_requiring_max_completion_tokens
619+
)
620+
621+
# If max_tokens is provided and model requires max_completion_tokens, map it
622+
if requires_max_completion_tokens and "max_tokens" in mapped_args:
623+
mapped_args["max_completion_tokens"] = mapped_args.pop("max_tokens")
624+
625+
return mapped_args
626+
597627
def _check_client_async(self) -> bool:
598628
"""Determine if the client is async-capable."""
599629
try:
@@ -699,11 +729,13 @@ def generate(
699729
**google_kwargs,
700730
)
701731
else:
732+
# Map parameters for OpenAI models requiring max_completion_tokens
733+
openai_kwargs = self._map_openai_params(self.model_args)
702734
result = self.client.chat.completions.create(
703735
model=self.model,
704736
messages=messages,
705737
response_model=response_model,
706-
**self.model_args,
738+
**openai_kwargs,
707739
)
708740

709741
# Track the usage
@@ -755,11 +787,13 @@ async def agenerate(
755787
**google_kwargs,
756788
)
757789
else:
790+
# Map parameters for OpenAI models requiring max_completion_tokens
791+
openai_kwargs = self._map_openai_params(self.model_args)
758792
result = await self.client.chat.completions.create(
759793
model=self.model,
760794
messages=messages,
761795
response_model=response_model,
762-
**self.model_args,
796+
**openai_kwargs,
763797
)
764798

765799
# Track the usage

0 commit comments

Comments
 (0)