@@ -155,7 +155,9 @@ def __init__(
155
155
self .stream : bool = False
156
156
if model is None :
157
157
models = self ._openai .models .list ().data
158
- if self ._is_valid_model ("gpt-4" , models = models ):
158
+ if self ._is_valid_model ("gpt-4-turbo" , models = models ):
159
+ self .model = "gpt-4-turbo"
160
+ elif self ._is_valid_model ("gpt-4" , models = models ):
159
161
self .model = "gpt-4"
160
162
elif self ._is_valid_model ("gpt-3.5-turbo" , models = models ):
161
163
self .model = "gpt-3.5-turbo"
@@ -230,7 +232,11 @@ def cost_cents(self) -> int:
230
232
raise CostEstimateUnavailableError (
231
233
"Unable to calculate token usage"
232
234
)
233
- if self .model .startswith ("gpt-4" ):
235
+ if self .model .startswith ("gpt-4-turbo" ):
236
+ return (1 * (self .prompt_tokens // 1000 )) + (
237
+ 3 * (self .sampled_tokens // 1000 )
238
+ )
239
+ elif self .model .startswith ("gpt-4" ):
234
240
return (3 * (self .prompt_tokens // 1000 )) + (
235
241
6 * (self .sampled_tokens // 1000 )
236
242
)
@@ -368,6 +374,7 @@ def _post_send(self, resp, stream_cls: Type[S]) -> Union[Message, S]:
368
374
if resp .model not in (
369
375
"gpt-4-0314" ,
370
376
"gpt-4-0613" ,
377
+ "gpt-4-turbo-2024-04-09" ,
371
378
) and resp .model .startswith ("gpt-4" ):
372
379
self .prompt_tokens = None
373
380
self .sampled_tokens = None
0 commit comments