|
| 1 | +import OpenAI from 'openai'; |
| 2 | +import { ApiHandler, LLMMessageParam } from '..'; |
| 3 | +import { |
| 4 | + MoudleInfo, |
| 5 | + LLMModelId, |
| 6 | + ApiHandlerOptions, |
| 7 | + LLMModelToParameters, |
| 8 | +} from '../../shared/api'; |
| 9 | +import { convertToOpenAiMessages } from '../transform/openaiFormat'; |
| 10 | +import { ApiStream } from '../transform/stream'; |
| 11 | + |
| 12 | +export class DoubaoHandler implements ApiHandler { |
| 13 | + #options: ApiHandlerOptions; |
| 14 | + #client: OpenAI; |
| 15 | + |
| 16 | + constructor(options: ApiHandlerOptions) { |
| 17 | + this.#options = options; |
| 18 | + this.#client = new OpenAI({ |
| 19 | + baseURL: options.endpoint, |
| 20 | + apiKey: options.apiKey, |
| 21 | + }); |
| 22 | + } |
| 23 | + |
| 24 | + getModel(): { id: LLMModelId; info: MoudleInfo } { |
| 25 | + const modelId = this.#options.apiModelId; |
| 26 | + if (modelId in LLMModelToParameters) { |
| 27 | + const id = modelId as LLMModelId; |
| 28 | + return { id, info: LLMModelToParameters[id] }; |
| 29 | + } else { |
| 30 | + throw new Error(`Model ID ${modelId} is not supported.`); |
| 31 | + } |
| 32 | + } |
| 33 | + |
| 34 | + async *createMessage(systemPrompt: string, messages: LLMMessageParam[]): ApiStream { |
| 35 | + const model = this.getModel(); |
| 36 | + const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ |
| 37 | + { role: 'system', content: systemPrompt }, |
| 38 | + ...convertToOpenAiMessages(messages), |
| 39 | + ]; |
| 40 | + const stream = await this.#client.chat.completions.create({ |
| 41 | + model: model.id, |
| 42 | + max_completion_tokens: model.info.max_completion_tokens, |
| 43 | + messages: openAiMessages, |
| 44 | + stream: true, |
| 45 | + stream_options: { include_usage: true }, |
| 46 | + temperature: 0, |
| 47 | + }); |
| 48 | + |
| 49 | + for await (const chunk of stream) { |
| 50 | + const delta = chunk.choices[0]?.delta; |
| 51 | + if (delta?.content) { |
| 52 | + yield { |
| 53 | + type: 'text', |
| 54 | + text: delta.content, |
| 55 | + }; |
| 56 | + } |
| 57 | + |
| 58 | + if (chunk.usage) { |
| 59 | + yield { |
| 60 | + type: 'usage', |
| 61 | + inputTokens: chunk.usage.prompt_tokens || 0, |
| 62 | + outputTokens: chunk.usage.completion_tokens || 0, |
| 63 | + // @ts-ignore-next-line |
| 64 | + cacheReadTokens: chunk.usage.prompt_cache_hit_tokens || 0, |
| 65 | + // @ts-ignore-next-line |
| 66 | + cacheWriteTokens: chunk.usage.prompt_cache_miss_tokens || 0, |
| 67 | + }; |
| 68 | + } |
| 69 | + } |
| 70 | + } |
| 71 | +} |
0 commit comments