From d6c0f3075d824ea1829ca9b88a50ab82469fed04 Mon Sep 17 00:00:00 2001
From: Kumar Anirudha
Date: Mon, 10 Nov 2025 12:28:52 +0530
Subject: [PATCH 1/2] fix: handle max_completion_tokens error for newer openai
 models

---
 src/ragas/llms/base.py | 38 ++++++++++++++++++++++++++++++++++++--
 1 file changed, 36 insertions(+), 2 deletions(-)

diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py
index 3fd3811f5..e37aa7da2 100644
--- a/src/ragas/llms/base.py
+++ b/src/ragas/llms/base.py
@@ -594,6 +594,36 @@ def __init__(
         # Check if client is async-capable at initialization
         self.is_async = self._check_client_async()
 
+    def _map_openai_params(self, model_args: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
+        """Map max_tokens to max_completion_tokens for o-series and newer OpenAI models.
+
+        O-series models (o1, o3, etc.) and some newer models like gpt-5-mini
+        require max_completion_tokens instead of the deprecated max_tokens parameter.
+        """
+        mapped_args = model_args.copy()
+
+        # List of models that require max_completion_tokens
+        models_requiring_max_completion_tokens = [
+            "o1",
+            "o3",
+            "o1-mini",
+            "o3-mini",
+            "gpt-5",
+            "gpt-5-mini",
+        ]
+
+        # Check if the model matches any of the patterns
+        model_lower = self.model.lower()
+        requires_max_completion_tokens = any(
+            pattern in model_lower for pattern in models_requiring_max_completion_tokens
+        )
+
+        # If max_tokens is provided and model requires max_completion_tokens, map it
+        if requires_max_completion_tokens and "max_tokens" in mapped_args:
+            mapped_args["max_completion_tokens"] = mapped_args.pop("max_tokens")
+
+        return mapped_args
+
     def _check_client_async(self) -> bool:
         """Determine if the client is async-capable."""
         try:
@@ -699,11 +729,13 @@ def generate(
                     **google_kwargs,
                 )
             else:
+                # Map parameters for OpenAI models requiring max_completion_tokens
+                openai_kwargs = self._map_openai_params(self.model_args)
                 result = self.client.chat.completions.create(
                     model=self.model,
                     messages=messages,
                     response_model=response_model,
-                    **self.model_args,
+                    **openai_kwargs,
                 )
 
             # Track the usage
@@ -755,11 +787,13 @@ async def agenerate(
                 **google_kwargs,
             )
         else:
+            # Map parameters for OpenAI models requiring max_completion_tokens
+            openai_kwargs = self._map_openai_params(self.model_args)
             result = await self.client.chat.completions.create(
                 model=self.model,
                 messages=messages,
                 response_model=response_model,
-                **self.model_args,
+                **openai_kwargs,
             )
 
         # Track the usage
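A quick sketch of what this first patch does, as standalone Python (illustration only; map_openai_params below is a hypothetical mirror of the method above, which actually lives on the LLM class and reads self.model):

import typing as t

# Hypothetical standalone mirror of patch 1's mapping, for illustration only.
def map_openai_params(model: str, model_args: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
    mapped_args = model_args.copy()
    patterns = ["o1", "o3", "o1-mini", "o3-mini", "gpt-5", "gpt-5-mini"]
    # Substring match, as in the patch above.
    if any(p in model.lower() for p in patterns) and "max_tokens" in mapped_args:
        # Newer models reject max_tokens; rename it to max_completion_tokens.
        mapped_args["max_completion_tokens"] = mapped_args.pop("max_tokens")
    return mapped_args

assert map_openai_params("o3-mini", {"max_tokens": 512}) == {"max_completion_tokens": 512}
assert map_openai_params("gpt-4o", {"max_tokens": 512}) == {"max_tokens": 512}

The substring check is deliberately simple but loose (any model id containing "o1" or "gpt-5" matches); the follow-up patch below replaces it with stricter prefix parsing and routes every provider through a single entry point.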
From 5c2e4e9483fe378c722d7aaf7b4642c66a734723 Mon Sep 17 00:00:00 2001
From: Kumar Anirudha
Date: Mon, 10 Nov 2025 12:57:17 +0530
Subject: [PATCH 2/2] refactor: implement provider-specific parameter mapping
 for InstructorLLM

---
 src/ragas/llms/base.py | 164 ++++++++++++++++++++++++++---------------
 1 file changed, 105 insertions(+), 59 deletions(-)

diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py
index e37aa7da2..2d55e0f7f 100644
--- a/src/ragas/llms/base.py
+++ b/src/ragas/llms/base.py
@@ -594,29 +594,79 @@ def __init__(
         # Check if client is async-capable at initialization
         self.is_async = self._check_client_async()
 
-    def _map_openai_params(self, model_args: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
-        """Map max_tokens to max_completion_tokens for o-series and newer OpenAI models.
+    def _map_provider_params(self) -> t.Dict[str, t.Any]:
+        """Route to provider-specific parameter mapping.
+
+        Each provider may have different parameter requirements:
+        - Google: Wraps parameters in generation_config and renames max_tokens
+        - OpenAI: Maps max_tokens to max_completion_tokens for o-series models
+        - Anthropic: No special handling required (pass-through)
+        - LiteLLM: No special handling required (routes internally, pass-through)
+        """
+        provider_lower = self.provider.lower()
+
+        if provider_lower == "google":
+            return self._map_google_params()
+        elif provider_lower == "openai":
+            return self._map_openai_params()
+        else:
+            # Anthropic, LiteLLM - pass through unchanged
+            return self.model_args.copy()
+
+    def _map_openai_params(self) -> t.Dict[str, t.Any]:
+        """Map max_tokens to max_completion_tokens for OpenAI reasoning models.
 
-        O-series models (o1, o3, etc.) and some newer models like gpt-5-mini
-        require max_completion_tokens instead of the deprecated max_tokens parameter.
+        Reasoning models (o-series and gpt-5 series) require max_completion_tokens
+        instead of the deprecated max_tokens parameter when using the Chat Completions API.
+
+        Legacy OpenAI models (gpt-4, gpt-4o, etc.) continue to use max_tokens unchanged.
+
+        Pattern-based matching for future-proof coverage:
+        - O-series: o1, o2, o3, o4, o5, ... (all reasoning versions)
+        - GPT-5 series: gpt-5, gpt-5-*, gpt-6, gpt-7, ... (all GPT-5+ models)
+        - Other: codex-mini
         """
-        mapped_args = model_args.copy()
-
-        # List of models that require max_completion_tokens
-        models_requiring_max_completion_tokens = [
-            "o1",
-            "o3",
-            "o1-mini",
-            "o3-mini",
-            "gpt-5",
-            "gpt-5-mini",
-        ]
+        mapped_args = self.model_args.copy()
 
-        # Check if the model matches any of the patterns
         model_lower = self.model.lower()
-        requires_max_completion_tokens = any(
-            pattern in model_lower for pattern in models_requiring_max_completion_tokens
-        )
+
+        # Pattern-based detection for reasoning models that require max_completion_tokens
+        # Uses prefix matching to cover current and future model variants
+        def is_reasoning_model(model_str: str) -> bool:
+            """Check if model is a reasoning model requiring max_completion_tokens."""
+            # O-series reasoning models (o1, o1-mini, o1-2024-12-17, o2, o3, ..., o9)
+            # Pattern: "o" followed by a single digit 1-9, then "-", "_", or end of string
+            # TODO: Update to support o10+ when OpenAI releases models beyond o9
+            if (
+                len(model_str) >= 2
+                and model_str[0] == "o"
+                and model_str[1] in "123456789"
+            ):
+                # Allow single digit o-series: o1, o2, ..., o9
+                if len(model_str) == 2 or model_str[2] in ("-", "_"):
+                    return True
+
+            # GPT-5 and newer generation models (gpt-5, gpt-5-*, gpt-6, gpt-7, ..., gpt-19)
+            # Pattern: "gpt-" followed by a single or double digit >= 5, max 19
+            # TODO: Update to support gpt-20+ when OpenAI releases models beyond gpt-19
+            if model_str.startswith("gpt-"):
+                version_str = (
+                    model_str[4:].split("-")[0].split("_")[0]
+                )  # Get version number
+                try:
+                    version = int(version_str)
+                    if 5 <= version <= 19:
+                        return True
+                except ValueError:
+                    pass
+
+            # Other specific reasoning models
+            if model_str == "codex-mini":
+                return True
+
+            return False
+
+        requires_max_completion_tokens = is_reasoning_model(model_lower)
 
         # If max_tokens is provided and model requires max_completion_tokens, map it
         if requires_max_completion_tokens and "max_tokens" in mapped_args:
@@ -624,6 +674,30 @@ def _map_openai_params(self, model_args: t.Dict[str, t.An
             mapped_args["max_completion_tokens"] = mapped_args.pop("max_tokens")
 
         return mapped_args
 
+    def _map_google_params(self) -> t.Dict[str, t.Any]:
+        """Map parameters for Google Gemini models.
+
+        Google models require parameters to be wrapped in a generation_config dict,
+        and max_tokens is renamed to max_output_tokens.
+        """
+        google_kwargs = {}
+        generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
+        generation_config = {}
+
+        for key, value in self.model_args.items():
+            if key in generation_config_keys:
+                if key == "max_tokens":
+                    generation_config["max_output_tokens"] = value
+                else:
+                    generation_config[key] = value
+            else:
+                google_kwargs[key] = value
+
+        if generation_config:
+            google_kwargs["generation_config"] = generation_config
+
+        return google_kwargs
+
     def _check_client_async(self) -> bool:
         """Determine if the client is async-capable."""
         try:
@@ -706,36 +780,22 @@ def generate(
                 self.agenerate(prompt, response_model)
             )
         else:
-            if self.provider.lower() == "google":
-                google_kwargs = {}
-                generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
-                generation_config = {}
-
-                for key, value in self.model_args.items():
-                    if key in generation_config_keys:
-                        if key == "max_tokens":
-                            generation_config["max_output_tokens"] = value
-                        else:
-                            generation_config[key] = value
-                    else:
-                        google_kwargs[key] = value
-
-                if generation_config:
-                    google_kwargs["generation_config"] = generation_config
+            # Map parameters based on provider requirements
+            provider_kwargs = self._map_provider_params()
 
+            if self.provider.lower() == "google":
                 result = self.client.create(
                     messages=messages,
                     response_model=response_model,
-                    **google_kwargs,
+                    **provider_kwargs,
                 )
             else:
-                # Map parameters for OpenAI models requiring max_completion_tokens
-                openai_kwargs = self._map_openai_params(self.model_args)
+                # OpenAI, Anthropic, LiteLLM
                 result = self.client.chat.completions.create(
                     model=self.model,
                     messages=messages,
                     response_model=response_model,
-                    **openai_kwargs,
+                    **provider_kwargs,
                 )
 
             # Track the usage
@@ -764,36 +824,22 @@ async def agenerate(
                 "Cannot use agenerate() with a synchronous client. Use generate() instead."
             )
 
-        if self.provider.lower() == "google":
-            google_kwargs = {}
-            generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
-            generation_config = {}
-
-            for key, value in self.model_args.items():
-                if key in generation_config_keys:
-                    if key == "max_tokens":
-                        generation_config["max_output_tokens"] = value
-                    else:
-                        generation_config[key] = value
-                else:
-                    google_kwargs[key] = value
-
-            if generation_config:
-                google_kwargs["generation_config"] = generation_config
+        # Map parameters based on provider requirements
+        provider_kwargs = self._map_provider_params()
 
+        if self.provider.lower() == "google":
             result = await self.client.create(
                 messages=messages,
                 response_model=response_model,
-                **google_kwargs,
+                **provider_kwargs,
            )
         else:
-            # Map parameters for OpenAI models requiring max_completion_tokens
-            openai_kwargs = self._map_openai_params(self.model_args)
+            # OpenAI, Anthropic, LiteLLM
             result = await self.client.chat.completions.create(
                 model=self.model,
                 messages=messages,
                 response_model=response_model,
-                **openai_kwargs,
+                **provider_kwargs,
             )
 
         # Track the usage
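After the refactor, the routing and both mappers behave as in this standalone sketch (hypothetical mirrors of the private helpers above, which on InstructorLLM read self.provider, self.model, and self.model_args; illustration only, not the ragas API):

import typing as t

# Hypothetical standalone mirrors of the patch's helpers, for illustration only.
def is_reasoning_model(model_str: str) -> bool:
    # o1..o9 followed by "-", "_", or end of string (o1, o3-mini, o1-2024-12-17).
    if len(model_str) >= 2 and model_str[0] == "o" and model_str[1] in "123456789":
        if len(model_str) == 2 or model_str[2] in ("-", "_"):
            return True
    # gpt-5 through gpt-19 generations (gpt-5, gpt-5-mini, gpt-6, ...).
    if model_str.startswith("gpt-"):
        version_str = model_str[4:].split("-")[0].split("_")[0]
        try:
            if 5 <= int(version_str) <= 19:
                return True
        except ValueError:
            pass
    return model_str == "codex-mini"

def map_provider_params(
    provider: str, model: str, model_args: t.Dict[str, t.Any]
) -> t.Dict[str, t.Any]:
    mapped = model_args.copy()
    if provider.lower() == "google":
        # Gemini nests sampling parameters under generation_config and
        # renames max_tokens to max_output_tokens.
        config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
        config = {
            ("max_output_tokens" if k == "max_tokens" else k): v
            for k, v in mapped.items()
            if k in config_keys
        }
        mapped = {k: v for k, v in mapped.items() if k not in config_keys}
        if config:
            mapped["generation_config"] = config
    elif provider.lower() == "openai" and is_reasoning_model(model.lower()):
        if "max_tokens" in mapped:
            mapped["max_completion_tokens"] = mapped.pop("max_tokens")
    # Anthropic and LiteLLM pass through unchanged.
    return mapped

assert map_provider_params("openai", "o3-mini", {"max_tokens": 512}) == {
    "max_completion_tokens": 512
}
assert map_provider_params("openai", "gpt-4o", {"max_tokens": 512}) == {"max_tokens": 512}
assert map_provider_params("google", "gemini-1.5-pro", {"max_tokens": 256, "seed": 7}) == {
    "seed": 7,
    "generation_config": {"max_output_tokens": 256},
}

This keeps generate() and agenerate() symmetric: each calls _map_provider_params() once and passes the result to the provider client, instead of duplicating the Google-specific block inline in both methods.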