@@ -11,9 +11,16 @@ import {
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
+import { calculateApiCostOpenAI } from "../../utils/cost"
 
 const OPENAI_NATIVE_DEFAULT_TEMPERATURE = 0
 
+// Define a type for the model object returned by getModel
+export type OpenAiNativeModel = {
+	id: OpenAiNativeModelId
+	info: ModelInfo
+}
+
 export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
 	private client: OpenAI
@@ -26,31 +33,31 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 	}
 
 	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const modelId = this.getModel().id
+		const model = this.getModel()
 
-		if (modelId.startsWith("o1")) {
-			yield* this.handleO1FamilyMessage(modelId, systemPrompt, messages)
+		if (model.id.startsWith("o1")) {
+			yield* this.handleO1FamilyMessage(model, systemPrompt, messages)
 			return
 		}
 
-		if (modelId.startsWith("o3-mini")) {
-			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
+		if (model.id.startsWith("o3-mini")) {
+			yield* this.handleO3FamilyMessage(model, systemPrompt, messages)
 			return
 		}
 
-		yield* this.handleDefaultModelMessage(modelId, systemPrompt, messages)
+		yield* this.handleDefaultModelMessage(model, systemPrompt, messages)
 	}
 
 	private async *handleO1FamilyMessage(
-		modelId: string,
+		model: OpenAiNativeModel,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
 		// o1 supports developer prompt with formatting
 		// o1-preview and o1-mini only support user messages
-		const isOriginalO1 = modelId === "o1"
+		const isOriginalO1 = model.id === "o1"
 		const response = await this.client.chat.completions.create({
-			model: modelId,
+			model: model.id,
 			messages: [
 				{
 					role: isOriginalO1 ? "developer" : "user",
@@ -62,11 +69,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			stream_options: { include_usage: true },
 		})
 
-		yield* this.handleStreamResponse(response)
+		yield* this.handleStreamResponse(response, model)
 	}
 
 	private async *handleO3FamilyMessage(
-		modelId: string,
+		model: OpenAiNativeModel,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
@@ -84,23 +91,23 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			reasoning_effort: this.getModel().info.reasoningEffort,
 		})
 
-		yield* this.handleStreamResponse(stream)
+		yield* this.handleStreamResponse(stream, model)
 	}
 
 	private async *handleDefaultModelMessage(
-		modelId: string,
+		model: OpenAiNativeModel,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
 		const stream = await this.client.chat.completions.create({
-			model: modelId,
+			model: model.id,
 			temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
 			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
 			stream: true,
 			stream_options: { include_usage: true },
 		})
 
-		yield* this.handleStreamResponse(stream)
+		yield* this.handleStreamResponse(stream, model)
 	}
 
 	private async *yieldResponseData(response: OpenAI.Chat.Completions.ChatCompletion): ApiStream {
@@ -115,7 +122,10 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 		}
 	}
 
-	private async *handleStreamResponse(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>): ApiStream {
+	private async *handleStreamResponse(
+		stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>,
+		model: OpenAiNativeModel,
+	): ApiStream {
 		for await (const chunk of stream) {
 			const delta = chunk.choices[0]?.delta
 			if (delta?.content) {
@@ -126,16 +136,29 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			}
 
 			if (chunk.usage) {
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
-				}
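+				// Delegate to yieldUsage so cached-token counts and cost are included.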
+				yield* this.yieldUsage(model.info, chunk.usage)
 			}
 		}
 	}
 
-	override getModel(): { id: OpenAiNativeModelId; info: ModelInfo } {
+	private async *yieldUsage(info: ModelInfo, usage: OpenAI.Completions.CompletionUsage | undefined): ApiStream {
+		const inputTokens = usage?.prompt_tokens || 0 // sum of cache hits and misses
+		const outputTokens = usage?.completion_tokens || 0
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
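+		// The usage payload reports cache reads but not cache writes, so writes are assumed to be zero here.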
+		const cacheWriteTokens = 0
+		const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
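+		// Surface only the non-cached share of the prompt so downstream consumers do not double count cached tokens.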
+		const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)
+		yield {
+			type: "usage",
+			inputTokens: nonCachedInputTokens,
+			outputTokens: outputTokens,
+			cacheWriteTokens: cacheWriteTokens,
+			cacheReadTokens: cacheReadTokens,
+			totalCost: totalCost,
+		}
+	}
+
+	override getModel(): OpenAiNativeModel {
 		const modelId = this.options.apiModelId
 		if (modelId && modelId in openAiNativeModels) {
 			const id = modelId as OpenAiNativeModelId
@@ -146,15 +169,15 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 
 	async completePrompt(prompt: string): Promise<string> {
 		try {
-			const modelId = this.getModel().id
+			const model = this.getModel()
 			let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
 
-			if (modelId.startsWith("o1")) {
-				requestOptions = this.getO1CompletionOptions(modelId, prompt)
-			} else if (modelId.startsWith("o3-mini")) {
-				requestOptions = this.getO3CompletionOptions(modelId, prompt)
+			if (model.id.startsWith("o1")) {
+				requestOptions = this.getO1CompletionOptions(model, prompt)
+			} else if (model.id.startsWith("o3-mini")) {
+				requestOptions = this.getO3CompletionOptions(model, prompt)
 			} else {
-				requestOptions = this.getDefaultCompletionOptions(modelId, prompt)
+				requestOptions = this.getDefaultCompletionOptions(model, prompt)
 			}
 
 			const response = await this.client.chat.completions.create(requestOptions)
@@ -168,17 +191,17 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 	}
 
 	private getO1CompletionOptions(
-		modelId: string,
+		model: OpenAiNativeModel,
 		prompt: string,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
 		return {
-			model: modelId,
+			model: model.id,
 			messages: [{ role: "user", content: prompt }],
 		}
 	}
 
 	private getO3CompletionOptions(
-		modelId: string,
+		model: OpenAiNativeModel,
 		prompt: string,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
 		return {
@@ -189,11 +212,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 		}
 	}
 	private getDefaultCompletionOptions(
-		modelId: string,
+		model: OpenAiNativeModel,
 		prompt: string,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
 		return {
-			model: modelId,
+			model: model.id,
 			messages: [{ role: "user", content: prompt }],
 			temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
 		}
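
For context, a minimal sketch of how a caller might consume the enriched stream, not part of the patch. The usage-chunk fields follow yieldUsage above; the text-chunk shape ({ type: "text", text }) and the exact ApiStream union are assumptions, since those parts of the file are outside this diff.

// Hypothetical consumer of the updated stream (illustrative only).
// Assumes ApiStream also yields { type: "text", text } chunks for content.
async function runWithCostTracking(handler: OpenAiNativeHandler): Promise<void> {
	let totalCost = 0
	const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Say hello." }]

	for await (const chunk of handler.createMessage("You are concise.", messages)) {
		if (chunk.type === "text") {
			process.stdout.write(chunk.text)
		} else if (chunk.type === "usage") {
			// inputTokens carries only the non-cached share; cache reads arrive separately.
			console.log(`in=${chunk.inputTokens} out=${chunk.outputTokens} cached=${chunk.cacheReadTokens ?? 0}`)
			totalCost += chunk.totalCost ?? 0
		}
	}

	console.log(`estimated cost: $${totalCost.toFixed(6)}`)
}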