@@ -594,36 +594,110 @@ def __init__(
594594 # Check if client is async-capable at initialization
595595 self .is_async = self ._check_client_async ()
596596
597- def _map_openai_params (self , model_args : t .Dict [str , t .Any ]) -> t .Dict [str , t .Any ]:
598- """Map max_tokens to max_completion_tokens for o-series and newer OpenAI models.
597+ def _map_provider_params (self ) -> t .Dict [str , t .Any ]:
598+ """Route to provider-specific parameter mapping.
599+
600+ Each provider may have different parameter requirements:
601+ - Google: Wraps parameters in generation_config and renames max_tokens
602+ - OpenAI: Maps max_tokens to max_completion_tokens for o-series models
603+ - Anthropic: No special handling required (pass-through)
604+ - LiteLLM: No special handling required (routes internally, pass-through)
605+ """
606+ provider_lower = self .provider .lower ()
607+
608+ if provider_lower == "google" :
609+ return self ._map_google_params ()
610+ elif provider_lower == "openai" :
611+ return self ._map_openai_params ()
612+ else :
613+ # Anthropic, LiteLLM - pass through unchanged
614+ return self .model_args .copy ()
615+
616+ def _map_openai_params (self ) -> t .Dict [str , t .Any ]:
617+ """Map max_tokens to max_completion_tokens for OpenAI reasoning models.
599618
600- O-series models (o1, o3, etc.) and some newer models like gpt-5-mini
601- require max_completion_tokens instead of the deprecated max_tokens parameter.
619+ Reasoning models (o-series and gpt-5 series) require max_completion_tokens
620+ instead of the deprecated max_tokens parameter when using Chat Completions API.
621+
622+ Legacy OpenAI models (gpt-4, gpt-4o, etc.) continue to use max_tokens unchanged.
623+
624+ Pattern-based matching for future-proof coverage:
625+ - O-series: o1, o2, o3, o4, o5, ... (all reasoning versions)
626+ - GPT-5 series: gpt-5, gpt-5-*, gpt-6, gpt-7, ... (all GPT-5+ models)
627+ - Other: codex-mini
602628 """
603- mapped_args = model_args .copy ()
604-
605- # List of models that require max_completion_tokens
606- models_requiring_max_completion_tokens = [
607- "o1" ,
608- "o3" ,
609- "o1-mini" ,
610- "o3-mini" ,
611- "gpt-5" ,
612- "gpt-5-mini" ,
613- ]
629+ mapped_args = self .model_args .copy ()
614630
615- # Check if the model matches any of the patterns
616631 model_lower = self .model .lower ()
617- requires_max_completion_tokens = any (
618- pattern in model_lower for pattern in models_requiring_max_completion_tokens
619- )
632+
633+ # Pattern-based detection for reasoning models that require max_completion_tokens
634+ # Uses prefix matching to cover current and future model variants
635+ def is_reasoning_model (model_str : str ) -> bool :
636+ """Check if model is a reasoning model requiring max_completion_tokens."""
637+ # O-series reasoning models (o1, o1-mini, o1-2024-12-17, o2, o3, o4, o5, o6, o7, o8, o9)
638+ # Pattern: "o" followed by single digit 1-9, then optional "-" or end of string
639+ # TODO: Update to support o10+ when OpenAI releases models beyond o9
640+ if (
641+ len (model_str ) >= 2
642+ and model_str [0 ] == "o"
643+ and model_str [1 ] in "123456789"
644+ ):
645+ # Allow single digit o-series: o1, o2, ..., o9
646+ if len (model_str ) == 2 or model_str [2 ] in ("-" , "_" ):
647+ return True
648+
649+ # GPT-5 and newer generation models (gpt-5, gpt-5-*, gpt-6, gpt-7, ..., gpt-19)
650+ # Pattern: "gpt-" followed by single or double digit >= 5, max 19
651+ # TODO: Update to support gpt-20+ when OpenAI releases models beyond gpt-19
652+ if model_str .startswith ("gpt-" ):
653+ version_str = (
654+ model_str [4 :].split ("-" )[0 ].split ("_" )[0 ]
655+ ) # Get version number
656+ try :
657+ version = int (version_str )
658+ if 5 <= version <= 19 :
659+ return True
660+ except ValueError :
661+ pass
662+
663+ # Other specific reasoning models
664+ if model_str == "codex-mini" :
665+ return True
666+
667+ return False
668+
669+ requires_max_completion_tokens = is_reasoning_model (model_lower )
620670
621671 # If max_tokens is provided and model requires max_completion_tokens, map it
622672 if requires_max_completion_tokens and "max_tokens" in mapped_args :
623673 mapped_args ["max_completion_tokens" ] = mapped_args .pop ("max_tokens" )
624674
625675 return mapped_args
626676
677+ def _map_google_params (self ) -> t .Dict [str , t .Any ]:
678+ """Map parameters for Google Gemini models.
679+
680+ Google models require parameters to be wrapped in a generation_config dict,
681+ and max_tokens is renamed to max_output_tokens.
682+ """
683+ google_kwargs = {}
684+ generation_config_keys = {"temperature" , "max_tokens" , "top_p" , "top_k" }
685+ generation_config = {}
686+
687+ for key , value in self .model_args .items ():
688+ if key in generation_config_keys :
689+ if key == "max_tokens" :
690+ generation_config ["max_output_tokens" ] = value
691+ else :
692+ generation_config [key ] = value
693+ else :
694+ google_kwargs [key ] = value
695+
696+ if generation_config :
697+ google_kwargs ["generation_config" ] = generation_config
698+
699+ return google_kwargs
700+
627701 def _check_client_async (self ) -> bool :
628702 """Determine if the client is async-capable."""
629703 try :
@@ -706,36 +780,22 @@ def generate(
706780 self .agenerate (prompt , response_model )
707781 )
708782 else :
709- if self .provider .lower () == "google" :
710- google_kwargs = {}
711- generation_config_keys = {"temperature" , "max_tokens" , "top_p" , "top_k" }
712- generation_config = {}
713-
714- for key , value in self .model_args .items ():
715- if key in generation_config_keys :
716- if key == "max_tokens" :
717- generation_config ["max_output_tokens" ] = value
718- else :
719- generation_config [key ] = value
720- else :
721- google_kwargs [key ] = value
722-
723- if generation_config :
724- google_kwargs ["generation_config" ] = generation_config
783+ # Map parameters based on provider requirements
784+ provider_kwargs = self ._map_provider_params ()
725785
786+ if self .provider .lower () == "google" :
726787 result = self .client .create (
727788 messages = messages ,
728789 response_model = response_model ,
729- ** google_kwargs ,
790+ ** provider_kwargs ,
730791 )
731792 else :
732- # Map parameters for OpenAI models requiring max_completion_tokens
733- openai_kwargs = self ._map_openai_params (self .model_args )
793+ # OpenAI, Anthropic, LiteLLM
734794 result = self .client .chat .completions .create (
735795 model = self .model ,
736796 messages = messages ,
737797 response_model = response_model ,
738- ** openai_kwargs ,
798+ ** provider_kwargs ,
739799 )
740800
741801 # Track the usage
@@ -764,36 +824,22 @@ async def agenerate(
764824 "Cannot use agenerate() with a synchronous client. Use generate() instead."
765825 )
766826
767- if self .provider .lower () == "google" :
768- google_kwargs = {}
769- generation_config_keys = {"temperature" , "max_tokens" , "top_p" , "top_k" }
770- generation_config = {}
771-
772- for key , value in self .model_args .items ():
773- if key in generation_config_keys :
774- if key == "max_tokens" :
775- generation_config ["max_output_tokens" ] = value
776- else :
777- generation_config [key ] = value
778- else :
779- google_kwargs [key ] = value
780-
781- if generation_config :
782- google_kwargs ["generation_config" ] = generation_config
827+ # Map parameters based on provider requirements
828+ provider_kwargs = self ._map_provider_params ()
783829
830+ if self .provider .lower () == "google" :
784831 result = await self .client .create (
785832 messages = messages ,
786833 response_model = response_model ,
787- ** google_kwargs ,
834+ ** provider_kwargs ,
788835 )
789836 else :
790- # Map parameters for OpenAI models requiring max_completion_tokens
791- openai_kwargs = self ._map_openai_params (self .model_args )
837+ # OpenAI, Anthropic, LiteLLM
792838 result = await self .client .chat .completions .create (
793839 model = self .model ,
794840 messages = messages ,
795841 response_model = response_model ,
796- ** openai_kwargs ,
842+ ** provider_kwargs ,
797843 )
798844
799845 # Track the usage
0 commit comments