@@ -110,46 +110,22 @@ def is_model_supported(self, model: str) -> bool:
         return model in _cached_list_models_from_openai(self.__api_key)
 
     def __get_model_limits(self, model: str) -> int:
+        """Return the token limit for a given model."""
         return self.__MODEL_LIMITS.get(model, 128_000)
+
+    def get_model_limit(self, model: str) -> int:
+        """
+        Public method to get the model's context length limit.
+
+        Args:
+            model: The model name
+
+        Returns:
+            The maximum context length in tokens
+        """
+        return self.__get_model_limits(model)
+
 
-    def is_prompt_supported(
-        self,
-        messages: Iterable[ChatCompletionMessageParam],
-        model: str,
-        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
-        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
-        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
-        n: Optional[int] | NotGiven = NOT_GIVEN,
-        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-        response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
-        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
-        temperature: Optional[float] | NotGiven = NOT_GIVEN,
-        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
-        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
-        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
-        top_p: Optional[float] | NotGiven = NOT_GIVEN,
-        file: Path | NotGiven = NOT_GIVEN,
-    ) -> int:
-        # might not implement model endpoint
-        if self.__is_not_openai_url():
-            return 1
-
-        model_limit = self.__get_model_limits(model)
-        token_count = 0
-        encoding = None
-        try:
-            encoding = tiktoken.encoding_for_model(model)
-        except Exception as e:
-            logger.error(f"Error getting encoding for model {model}: {e}, using gpt-4o as fallback")
-            encoding = tiktoken.encoding_for_model("gpt-4o")
-        for message in messages:
-            message_token_count = len(encoding.encode(message.get("content")))
-            token_count = token_count + message_token_count
-            if token_count > model_limit:
-                return -1
-
-        return model_limit - token_count
 
     def truncate_messages(
         self, messages: Iterable[ChatCompletionMessageParam], model: str
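
A minimal usage sketch of the new public method, for reviewers. The wrapper is needed because Python name-mangles the double-underscore helper (`__get_model_limits` becomes `_ClassName__get_model_limits`), so external callers had no clean way to read the limit. The `OpenAIClient` class name and constructor arguments below are assumptions for illustration; only `get_model_limit` and the `128_000` fallback come from this change.

```python
# Hypothetical usage of the new public method. `OpenAIClient` and its
# constructor arguments are assumed names, not taken from this diff.
client = OpenAIClient(api_key="sk-...")

# A model present in __MODEL_LIMITS returns its configured limit.
print(client.get_model_limit("gpt-4o"))

# An unknown model falls back to the 128_000 default.
print(client.get_model_limit("some-future-model"))  # -> 128000

# The private helper is name-mangled (client._OpenAIClient__get_model_limits),
# which is why the public wrapper exists.
```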