@@ -744,6 +744,110 @@ def __init__(
        # Check if client is async-capable at initialization
        self.is_async = self._check_client_async()

+    def _map_provider_params(self) -> t.Dict[str, t.Any]:
+        """Route to provider-specific parameter mapping.
+
+        Each provider may have different parameter requirements:
+        - Google: wraps parameters in generation_config and renames max_tokens
+        - OpenAI: maps max_tokens to max_completion_tokens for reasoning models
+        - Anthropic: no special handling required (pass-through)
+        - LiteLLM: no special handling required (routes internally, pass-through)
+        """
+        provider_lower = self.provider.lower()
+
+        if provider_lower == "google":
+            return self._map_google_params()
+        elif provider_lower == "openai":
+            return self._map_openai_params()
+        else:
+            # Anthropic, LiteLLM - pass through unchanged
+            return self.model_args.copy()
+
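
# Minimal standalone sketch of the dispatch above, with stubbed mappers since
# the real ones live on the class; names here are illustrative, not part of
# the patch. It shows the contract: unknown providers fall through to a copy.
def _route_params_sketch(provider: str, model_args: dict) -> dict:
    mappers = {
        "google": lambda: {"generation_config": dict(model_args)},  # simplified stub
        "openai": lambda: dict(model_args),  # real mapper may rename max_tokens
    }
    # Anthropic, LiteLLM, and anything unrecognised: plain pass-through copy
    return mappers.get(provider.lower(), lambda: model_args.copy())()

assert _route_params_sketch("Anthropic", {"max_tokens": 64}) == {"max_tokens": 64}
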
+    def _map_openai_params(self) -> t.Dict[str, t.Any]:
+        """Map max_tokens to max_completion_tokens for OpenAI reasoning models.
+
+        Reasoning models (o-series and gpt-5 series) require max_completion_tokens
+        instead of the deprecated max_tokens parameter when using the Chat
+        Completions API. Legacy OpenAI models (gpt-4, gpt-4o, etc.) continue to
+        use max_tokens unchanged.
+
+        Pattern-based matching keeps coverage future-proof:
+        - O-series: o1 through o9, including suffixed variants (o1-mini, ...)
+        - GPT-5 series: gpt-5 through gpt-19, including suffixed variants
+        - Other: codex-mini
+        """
+        mapped_args = self.model_args.copy()
+
+        model_lower = self.model.lower()
+
+        # Pattern-based detection for reasoning models that require
+        # max_completion_tokens; uses prefix matching to cover current and
+        # future model variants.
+        def is_reasoning_model(model_str: str) -> bool:
+            """Check if model is a reasoning model requiring max_completion_tokens."""
+            # O-series reasoning models (o1, o1-mini, o1-2024-12-17, ..., o9).
+            # Pattern: "o" followed by a single digit 1-9, then "-", "_", or
+            # end of string.
+            # TODO: Update to support o10+ when OpenAI releases models beyond o9
+            if (
+                len(model_str) >= 2
+                and model_str[0] == "o"
+                and model_str[1] in "123456789"
+            ):
+                # Allow single-digit o-series: o1, o2, ..., o9
+                if len(model_str) == 2 or model_str[2] in ("-", "_"):
+                    return True
+
+            # GPT-5 and newer generations (gpt-5, gpt-5-*, gpt-6, ..., gpt-19).
+            # Pattern: "gpt-" followed by a one- or two-digit version in [5, 19].
+            # TODO: Update to support gpt-20+ when OpenAI releases models beyond gpt-19
+            if model_str.startswith("gpt-"):
+                version_str = model_str[4:].split("-")[0].split("_")[0]  # version number
+                try:
+                    version = int(version_str)
+                    if 5 <= version <= 19:
+                        return True
+                except ValueError:
+                    pass
+
+            # Other specific reasoning models
+            if model_str == "codex-mini":
+                return True
+
+            return False
+
+        requires_max_completion_tokens = is_reasoning_model(model_lower)
+
+        # If max_tokens is provided and the model requires max_completion_tokens,
+        # rename it
+        if requires_max_completion_tokens and "max_tokens" in mapped_args:
+            mapped_args["max_completion_tokens"] = mapped_args.pop("max_tokens")
+
+        return mapped_args
+
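
# Standalone sketch of the detection logic above, so the matching rules can be
# spot-checked in isolation; the helper name is illustrative, the behaviour
# mirrors is_reasoning_model() in the diff.
def _is_reasoning_model_sketch(model_str: str) -> bool:
    # o1..o9, optionally followed by "-" or "_" suffixes (o1-mini, o4_2025)
    if len(model_str) >= 2 and model_str[0] == "o" and model_str[1] in "123456789":
        if len(model_str) == 2 or model_str[2] in ("-", "_"):
            return True
    # gpt-5 .. gpt-19, optionally suffixed (gpt-5-turbo)
    if model_str.startswith("gpt-"):
        version_str = model_str[4:].split("-")[0].split("_")[0]
        try:
            if 5 <= int(version_str) <= 19:
                return True
        except ValueError:
            pass  # e.g. "gpt-4o" -> int("4o") fails, so it stays a legacy model
    return model_str == "codex-mini"

assert _is_reasoning_model_sketch("o1-mini") and _is_reasoning_model_sketch("gpt-5-turbo")
assert not _is_reasoning_model_sketch("gpt-4o") and not _is_reasoning_model_sketch("omni")
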
+    def _map_google_params(self) -> t.Dict[str, t.Any]:
+        """Map parameters for Google Gemini models.
+
+        Google models require parameters to be wrapped in a generation_config
+        dict, and max_tokens is renamed to max_output_tokens.
+        """
+        google_kwargs = {}
+        generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
+        generation_config = {}
+
+        for key, value in self.model_args.items():
+            if key in generation_config_keys:
+                if key == "max_tokens":
+                    generation_config["max_output_tokens"] = value
+                else:
+                    generation_config[key] = value
+            else:
+                google_kwargs[key] = value
+
+        if generation_config:
+            google_kwargs["generation_config"] = generation_config
+
+        return google_kwargs
+
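
# Standalone sketch of the Google mapping above for quick verification; it
# mirrors _map_google_params() but takes a plain dict instead of reading
# self.model_args.
def _map_google_params_sketch(model_args: dict) -> dict:
    config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
    kwargs, config = {}, {}
    for key, value in model_args.items():
        if key in config_keys:
            config["max_output_tokens" if key == "max_tokens" else key] = value
        else:
            kwargs[key] = value
    if config:
        kwargs["generation_config"] = config
    return kwargs

assert _map_google_params_sketch({"max_tokens": 128, "top_p": 0.9, "seed": 7}) == {
    "seed": 7,
    "generation_config": {"max_output_tokens": 128, "top_p": 0.9},
}
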
    def _check_client_async(self) -> bool:
        """Determine if the client is async-capable."""
        try:
@@ -826,34 +930,22 @@ def generate(
                self.agenerate(prompt, response_model)
            )
        else:
-            if self.provider.lower() == "google":
-                google_kwargs = {}
-                generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
-                generation_config = {}
-
-                for key, value in self.model_args.items():
-                    if key in generation_config_keys:
-                        if key == "max_tokens":
-                            generation_config["max_output_tokens"] = value
-                        else:
-                            generation_config[key] = value
-                    else:
-                        google_kwargs[key] = value
-
-                if generation_config:
-                    google_kwargs["generation_config"] = generation_config
+            # Map parameters based on provider requirements
+            provider_kwargs = self._map_provider_params()

+            if self.provider.lower() == "google":
                result = self.client.create(
                    messages=messages,
                    response_model=response_model,
-                    **google_kwargs,
+                    **provider_kwargs,
                )
            else:
+                # OpenAI, Anthropic, LiteLLM
                result = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    response_model=response_model,
-                    **self.model_args,
+                    **provider_kwargs,
                )

        # Track the usage
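
# Hedged call-site sketch: with mapping centralised in _map_provider_params(),
# callers keep passing provider-agnostic model_args. The constructor and the
# response model below are assumptions for illustration, not part of this code.
#
#     llm = InstructorLLM(provider="openai", model="o3-mini", max_tokens=512)
#     answer = llm.generate(prompt, response_model=Answer)
#     # the request is sent with max_completion_tokens=512, not max_tokens
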
@@ -882,34 +974,22 @@ async def agenerate(
882974 "Cannot use agenerate() with a synchronous client. Use generate() instead."
883975 )
884976
885- if self .provider .lower () == "google" :
886- google_kwargs = {}
887- generation_config_keys = {"temperature" , "max_tokens" , "top_p" , "top_k" }
888- generation_config = {}
889-
890- for key , value in self .model_args .items ():
891- if key in generation_config_keys :
892- if key == "max_tokens" :
893- generation_config ["max_output_tokens" ] = value
894- else :
895- generation_config [key ] = value
896- else :
897- google_kwargs [key ] = value
898-
899- if generation_config :
900- google_kwargs ["generation_config" ] = generation_config
977+ # Map parameters based on provider requirements
978+ provider_kwargs = self ._map_provider_params ()
901979
980+ if self .provider .lower () == "google" :
902981 result = await self .client .create (
903982 messages = messages ,
904983 response_model = response_model ,
905- ** google_kwargs ,
984+ ** provider_kwargs ,
906985 )
907986 else :
987+ # OpenAI, Anthropic, LiteLLM
908988 result = await self .client .chat .completions .create (
909989 model = self .model ,
910990 messages = messages ,
911991 response_model = response_model ,
912- ** self . model_args ,
992+ ** provider_kwargs ,
913993 )
914994
915995 # Track the usage
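
# Hedged async sketch, mirroring the sync call path above; assumes an
# async-capable client was detected at init (self.is_async), otherwise
# agenerate() raises with the guard message shown above.
#
#     answer = await llm.agenerate(prompt, response_model=Answer)
#
# Both paths call _map_provider_params(), so sync and async requests carry
# identical provider-specific kwargs.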