diff --git a/package.json b/package.json
index fc3cd83..967362b 100644
--- a/package.json
+++ b/package.json
@@ -34,11 +34,13 @@
   "license": "MIT",
   "dependencies": {
     "@anthropic-ai/sdk": "0.24.3",
+    "@anthropic-ai/vertex-sdk": "^0.4.1",
     "@aws-sdk/client-bedrock-runtime": "3.609.0",
     "@google/generative-ai": "0.14.1",
     "@mistralai/mistralai": "0.5.0",
     "chalk": "^4.1.2",
     "cohere-ai": "7.10.6",
+    "google-auth-library": "^9.2.0",
     "mime-types": "^2.1.35",
     "nanoid": "^5.0.7",
     "openai": "4.91.1"
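The two new dependencies work together: `google-auth-library` builds the Google Cloud credentials, and `@anthropic-ai/vertex-sdk` exposes an Anthropic-compatible client bound to a GCP project and region. A minimal sketch of that pairing, mirroring the handler code added later in this diff (the project ID is an assumption; with no explicit credentials, `GoogleAuth` falls back to Application Default Credentials):

```ts
import { AnthropicVertex } from '@anthropic-ai/vertex-sdk'
import { GoogleAuth } from 'google-auth-library'

// Same scope the new handler requests.
const googleAuth = new GoogleAuth({
  scopes: 'https://www.googleapis.com/auth/cloud-platform',
})

const client = new AnthropicVertex({
  region: 'europe-west1', // the handler's default region
  projectId: 'my-gcp-project', // assumption: your GCP project ID
  googleAuth,
})
```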
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 9aa287b..1ff9820 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -8,6 +8,9 @@ dependencies:
   '@anthropic-ai/sdk':
     specifier: 0.24.3
     version: 0.24.3
+  '@anthropic-ai/vertex-sdk':
+    specifier: ^0.4.1
+    version: 0.4.1
   '@aws-sdk/client-bedrock-runtime':
     specifier: 3.609.0
     version: 3.609.0
@@ -23,6 +26,9 @@ dependencies:
   cohere-ai:
     specifier: 7.10.6
     version: 7.10.6(@aws-sdk/client-sso-oidc@3.693.0)
+  google-auth-library:
+    specifier: ^9.2.0
+    version: 9.2.0
   mime-types:
     specifier: ^2.1.35
     version: 2.1.35
@@ -135,6 +141,16 @@ packages:
       - encoding
     dev: false

+  /@anthropic-ai/vertex-sdk@0.4.1:
+    resolution: {integrity: sha512-RT/2CWzqyAcJDZWxnNc1mXa7XiiHDaQ9aknfW4mIDw6zE+Zj/R2vCKpTb0dIwrmHYNOyKQNaD7Z1ynDt9oXFWA==}
+    dependencies:
+      '@anthropic-ai/sdk': 0.24.3
+      google-auth-library: 9.15.1
+    transitivePeerDependencies:
+      - encoding
+      - supports-color
+    dev: false
+
   /@aws-crypto/crc32@3.0.0:
     resolution: {integrity: sha512-IzSgsrxUcsrejQbPVilIKy16kAT52EwB6zSaI+M3xxIhKh5+aldEyvI+z6erM7TCLB2BJsFrtHjp6/4/sr+3dA==}
     dependencies:
@@ -3446,6 +3462,11 @@ packages:
     hasBin: true
     dev: true

+  /agent-base@7.1.3:
+    resolution: {integrity: sha512-jRR5wdylq8CkOe6hei19GGZnxM6rBGwFl3Bg0YItGDimvjGtAvdZk4Pu6Cl4u4Igsws4a1fd1Vq3ezrhn4KmFw==}
+    engines: {node: '>= 14'}
+    dev: false
+
   /agentkeepalive@4.5.0:
     resolution: {integrity: sha512-5GG/5IbQQpC9FpkRGsSvZI5QYeSCzlJHdpBQntCsuTOxhKD8lqKhrleg2Yi7yvMIf82Ycmmqln9U8V9qwEiJew==}
     engines: {node: '>= 8.0.0'}
@@ -3706,6 +3727,10 @@ packages:
     resolution: {integrity: sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==}
     dev: false

+  /bignumber.js@9.3.0:
+    resolution: {integrity: sha512-EM7aMFTXbptt/wZdMlBv2t8IViwQL+h6SLHosp8Yf0dqJMTnY6iL32opnAB6kAdL0SZPuvcAzFr31o0c/R3/RA==}
+    dev: false
+
   /bowser@2.11.0:
     resolution: {integrity: sha512-AlcaJBi/pqqJBIQ8U9Mcpc9i8Aqxn88Skv5d+xBX006BY5u8N3mGLHa5Lgppa7L/HfwgwLgZ6NYs+Ag6uUmJRA==}
     dev: false
@@ -3747,6 +3772,10 @@ packages:
     node-int64: 0.4.0
     dev: true

+  /buffer-equal-constant-time@1.0.1:
+    resolution: {integrity: sha512-zRpUiDwd/xk6ADqPMATG8vc9VPrkck7T07OIx0gnjmJAnHnTVXNQG3vfvWNuiZIkwu9KrKdA1iJKfsfTVxE6NA==}
+    dev: false
+
   /buffer-from@1.1.2:
     resolution: {integrity: sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==}
     dev: true
@@ -3985,7 +4014,6 @@ packages:
         optional: true
     dependencies:
       ms: 2.1.3
-    dev: true

   /dedent@1.5.3:
     resolution: {integrity: sha512-NHQtfOOW68WD8lgypbLA5oT+Bt0xXJhiYvoR6SmmNXZfpzOGXwdKWmcwG8N7PwVVWV3eF/68nmD9BaJSsTBhyQ==}
@@ -4072,6 +4100,12 @@ packages:
     resolution: {integrity: sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA==}
     dev: true

+  /ecdsa-sig-formatter@1.0.11:
+    resolution: {integrity: sha512-nagl3RYrbNv6kQkeJIpt6NJZy8twLB/2vtz6yN9Z4vRKHN4/QZJIEbqohALSgwKdnksuY3k5Addp5lg8sVoVcQ==}
+    dependencies:
+      safe-buffer: 5.2.1
+    dev: false
+
   /electron-to-chromium@1.5.62:
     resolution: {integrity: sha512-t8c+zLmJHa9dJy96yBZRXGQYoiCEnHYgFwn1asvSPZSUdVxnB62A4RASd7k41ytG3ErFBA0TpHlKg9D9SQBmLg==}
     dev: true
@@ -4756,6 +4790,10 @@ packages:
     jest-util: 29.7.0
     dev: true

+  /extend@3.0.2:
+    resolution: {integrity: sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g==}
+    dev: false
+
   /fast-deep-equal@3.1.3:
     resolution: {integrity: sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==}
     dev: true
@@ -4950,6 +4988,32 @@ packages:
     resolution: {integrity: sha512-xckBUXyTIqT97tq2x2AMb+g163b5JFysYk0x4qxNFwbfQkmNZoiRHb6sPzI9/QV33WeuvVYBUIiD4NzNIyqaRQ==}
     dev: true

+  /gaxios@6.7.1:
+    resolution: {integrity: sha512-LDODD4TMYx7XXdpwxAVRAIAuB0bzv0s+ywFonY46k126qzQHT9ygyoa9tncmOiQmmDrik65UYsEkv3lbfqQ3yQ==}
+    engines: {node: '>=14'}
+    dependencies:
+      extend: 3.0.2
+      https-proxy-agent: 7.0.6
+      is-stream: 2.0.1
+      node-fetch: 2.7.0
+      uuid: 9.0.1
+    transitivePeerDependencies:
+      - encoding
+      - supports-color
+    dev: false
+
+  /gcp-metadata@6.1.1:
+    resolution: {integrity: sha512-a4tiq7E0/5fTjxPAaH4jpjkSv/uCaU2p5KC6HVGrvl0cDjA8iBZv4vv1gyzlmK0ZUKqwpOyQMKzZQe3lTit77A==}
+    engines: {node: '>=14'}
+    dependencies:
+      gaxios: 6.7.1
+      google-logging-utils: 0.0.2
+      json-bigint: 1.0.0
+    transitivePeerDependencies:
+      - encoding
+      - supports-color
+    dev: false
+
   /gensync@1.0.0-beta.2:
     resolution: {integrity: sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==}
     engines: {node: '>=6.9.0'}
@@ -5065,6 +5129,41 @@ packages:
     slash: 3.0.0
     dev: true

+  /google-auth-library@9.15.1:
+    resolution: {integrity: sha512-Jb6Z0+nvECVz+2lzSMt9u98UsoakXxA2HGHMCxh+so3n90XgYWkq5dur19JAJV7ONiJY22yBTyJB1TSkvPq9Ng==}
+    engines: {node: '>=14'}
+    dependencies:
+      base64-js: 1.5.1
+      ecdsa-sig-formatter: 1.0.11
+      gaxios: 6.7.1
+      gcp-metadata: 6.1.1
+      gtoken: 7.1.0
+      jws: 4.0.0
+    transitivePeerDependencies:
+      - encoding
+      - supports-color
+    dev: false
+
+  /google-auth-library@9.2.0:
+    resolution: {integrity: sha512-1oV3p0JhNEhVbj26eF3FAJcv9MXXQt4S0wcvKZaDbl4oHq5V3UJoSbsGZGQNcjoCdhW4kDSwOs11wLlHog3fgQ==}
+    engines: {node: '>=14'}
+    dependencies:
+      base64-js: 1.5.1
+      ecdsa-sig-formatter: 1.0.11
+      gaxios: 6.7.1
+      gcp-metadata: 6.1.1
+      gtoken: 7.1.0
+      jws: 4.0.0
+    transitivePeerDependencies:
+      - encoding
+      - supports-color
+    dev: false
+
+  /google-logging-utils@0.0.2:
+    resolution: {integrity: sha512-NEgUnEcBiP5HrPzufUkBzJOD/Sxsco3rLNo1F1TNf7ieU8ryUzBhqba8r756CjLX7rn3fHl6iLEwPYuqpoKgQQ==}
+    engines: {node: '>=14'}
+    dev: false
+
   /gopd@1.0.1:
     resolution: {integrity: sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==}
     dependencies:
@@ -5078,6 +5177,17 @@ packages:
     resolution: {integrity: sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==}
     dev: true

+  /gtoken@7.1.0:
+    resolution: {integrity: sha512-pCcEwRi+TKpMlxAQObHDQ56KawURgyAf6jtIY046fJ5tIv3zDe/LEIubckAO8fj6JnAxLdmWkUfNyulQ2iKdEw==}
+    engines: {node: '>=14.0.0'}
+    dependencies:
+      gaxios: 6.7.1
+      jws: 4.0.0
+    transitivePeerDependencies:
+      - encoding
+      - supports-color
+    dev: false
+
   /has-bigints@1.0.2:
     resolution: {integrity: sha512-tSvCKtBr9lkF0Ex0aQiP9N+OpV4zi2r/Nee5VkRDbaqv35RLYMzbwQfFSZZH0kR+Rd6302UJZ2p/bJCEoR3VoQ==}
     dev: true
@@ -5120,6 +5230,16 @@ packages:
     resolution: {integrity: sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==}
     dev: true

+  /https-proxy-agent@7.0.6:
+    resolution: {integrity: sha512-vK9P5/iUfdl95AI+JVyUuIcVtd4ofvtrOr3HNtM2yxC9bnMbEdp3x01OhQNnjb8IJYi38VlTE3mBXwcfvywuSw==}
+    engines: {node: '>= 14'}
+    dependencies:
+      agent-base: 7.1.3
+      debug: 4.3.7
+    transitivePeerDependencies:
+      - supports-color
+    dev: false
+
   /human-signals@2.1.0:
     resolution: {integrity: sha512-B4FFZ6q/T2jhhksgkbEW3HBvWIfDW85snkQgawt07S7J5QXTk6BkNV+0yAeZrM5QpMAdYlocGoljn0sJ/WQkFw==}
     engines: {node: '>=10.17.0'}
@@ -5345,7 +5465,6 @@ packages:
   /is-stream@2.0.1:
     resolution: {integrity: sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==}
     engines: {node: '>=8'}
-    dev: true

   /is-string@1.0.7:
     resolution: {integrity: sha512-tE2UXzivje6ofPW7l23cjDOMa09gb7xlAqG6jG5ej6uPV32TlWP3NKPigtaGeHNu9fohccRYvIiZMfOOnOYUtg==}
@@ -5933,6 +6052,12 @@ packages:
     hasBin: true
     dev: true

+  /json-bigint@1.0.0:
+    resolution: {integrity: sha512-SiPv/8VpZuWbvLSMtTDU8hEfrZWg/mH/nV/b4o0CYbSxu1UIQPLdwKOCIyLQX+VIPO5vrLX3i8qtqFyhdPSUSQ==}
+    dependencies:
+      bignumber.js: 9.3.0
+    dev: false
+
   /json-buffer@3.0.1:
     resolution: {integrity: sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==}
     dev: true
@@ -5972,6 +6097,21 @@ packages:
     object.values: 1.2.0
     dev: true

+  /jwa@2.0.1:
+    resolution: {integrity: sha512-hRF04fqJIP8Abbkq5NKGN0Bbr3JxlQ+qhZufXVr0DvujKy93ZCbXZMHDL4EOtodSbCWxOqR8MS1tXA5hwqCXDg==}
+    dependencies:
+      buffer-equal-constant-time: 1.0.1
+      ecdsa-sig-formatter: 1.0.11
+      safe-buffer: 5.2.1
+    dev: false
+
+  /jws@4.0.0:
+    resolution: {integrity: sha512-KDncfTmOZoOMTFG4mBlG0qUIOlc03fmzH+ru6RgYVZhPkyiy/92Owlt/8UEN+a4TXR1FQetfIpJE8ApdvdVxTg==}
+    dependencies:
+      jwa: 2.0.1
+      safe-buffer: 5.2.1
+    dev: false
+
   /keyv@4.5.4:
     resolution: {integrity: sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==}
     dependencies:
diff --git a/src/chat/index.ts b/src/chat/index.ts
index bc1e8ad..41658c5 100644
--- a/src/chat/index.ts
+++ b/src/chat/index.ts
@@ -11,6 +11,8 @@ import {
 export type OpenAIModel = (typeof models.openai.models)[number]
 export type AI21Model = (typeof models.ai21.models)[number]
 export type AnthropicModel = (typeof models.anthropic.models)[number]
+export type AnthropicVertexModel =
+  (typeof models)['anthropic-vertex']['models'][number]
 export type GeminiModel = (typeof models.gemini.models)[number]
 export type CohereModel = (typeof models.cohere.models)[number]
 export type BedrockModel = (typeof models.bedrock.models)[number]
@@ -24,6 +26,7 @@ export type LLMChatModel =
   | OpenAIModel
   | AI21Model
   | AnthropicModel
+  | AnthropicVertexModel
   | GeminiModel
   | CohereModel
   | BedrockModel
@@ -39,6 +42,7 @@ type ProviderModelMap = {
   openai: OpenAIModel
   ai21: AI21Model
   anthropic: AnthropicModel
+  'anthropic-vertex': AnthropicVertexModel
   gemini: GeminiModel
   cohere: CohereModel
   bedrock: BedrockModel
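The new `AnthropicVertexModel` union and the `'anthropic-vertex'` entry in `ProviderModelMap` give consumers compile-time checking of Vertex model IDs. A small illustrative sketch (the import path assumes a consumer of this module):

```ts
import { AnthropicVertexModel } from './src/chat/index.js'

// Only the Vertex-hosted Claude IDs declared in src/models.ts type-check here:
const model: AnthropicVertexModel = 'claude-3-7-sonnet@20250219'
// const bad: AnthropicVertexModel = 'gpt-4o' // rejected at compile time
```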
diff --git a/src/handlers/anthropic-vertex.ts b/src/handlers/anthropic-vertex.ts
new file mode 100644
index 0000000..e1ae69b
--- /dev/null
+++ b/src/handlers/anthropic-vertex.ts
@@ -0,0 +1,644 @@
+import {
+  ContentBlock,
+  ImageBlockParam,
+  Message,
+  MessageCreateParamsNonStreaming,
+  MessageCreateParamsStreaming,
+  MessageStream,
+  TextBlock,
+  TextBlockParam,
+  ToolResultBlockParam,
+  ToolUseBlock,
+  ToolUseBlockParam,
+} from '@anthropic-ai/sdk/resources/messages'
+import { AnthropicVertex } from '@anthropic-ai/vertex-sdk'
+import { CredentialBody, GoogleAuth } from 'google-auth-library'
+import { ChatCompletionMessageToolCall } from 'openai/resources/index'
+
+import {
+  AnthropicVertexModel,
+  CompletionParams,
+  ProviderCompletionParams,
+} from '../chat/index.js'
+import {
+  CompletionResponse,
+  CompletionResponseChunk,
+  StreamCompletionResponse,
+} from '../userTypes/index.js'
+import { BaseHandler } from './base.js'
+import { InputError, InvariantError } from './types.js'
+import {
+  consoleWarn,
+  convertMessageContentToString,
+  fetchThenParseImage,
+  getTimestamp,
+  isEmptyObject,
+} from './utils.js'
+
+export const createCompletionResponseNonStreaming = (
+  response: Message,
+  created: number,
+  toolChoice: CompletionParams['tool_choice']
+): CompletionResponse => {
+  const finishReason = toFinishReasonNonStreaming(response.stop_reason)
+  const chatMessage = toChatCompletionChoiceMessage(
+    response.content,
+    response.role,
+    toolChoice
+  )
+  const choice = {
+    index: 0,
+    logprobs: null,
+    message: chatMessage,
+    finish_reason: finishReason,
+  }
+  const converted: CompletionResponse = {
+    id: response.id,
+    choices: [choice],
+    created,
+    model: response.model,
+    object: 'chat.completion',
+    usage: {
+      prompt_tokens: response.usage.input_tokens,
+      completion_tokens: response.usage.output_tokens,
+      total_tokens: response.usage.input_tokens + response.usage.output_tokens,
+    },
+  }
+
+  return converted
+}
+
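+// For illustration (an assumed example, not part of the original logic): if
+// Anthropic emits a text block at index 0 followed by tool-use blocks at
+// indices 1 and 2, the normalization below records initialToolCallIndex = 1
+// and reports the tool calls to the caller as OpenAI-style indices 0 and 1.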
+export async function* createCompletionResponseStreaming(
+  response: MessageStream,
+  created: number
+): StreamCompletionResponse {
+  let message: Message | undefined
+
+  // We manually keep track of the tool call index because some providers, like Anthropic, start
+  // with a tool call index of 1 because they're preceded by a text block that has an index of 0 in
+  // the `response`. Since OpenAI's tool call index starts with 0, we also enforce that convention
+  // here for consistency.
+  let initialToolCallIndex: number | null = null
+
+  let inputTokens = 0
+  let outputTokens = 0
+
+  for await (const chunk of response) {
+    if (chunk.type === 'message_start') {
+      inputTokens = chunk.message.usage.input_tokens
+
+      message = chunk.message
+      // Yield the first element
+      yield {
+        choices: [
+          {
+            index: 0,
+            finish_reason: toFinishReasonStreaming(chunk.message.stop_reason),
+            logprobs: null,
+            delta: {
+              role: chunk.message.role,
+            },
+          },
+        ],
+        created,
+        model: message.model,
+        id: message.id,
+        object: 'chat.completion.chunk',
+      }
+    }
+
+    if (message === undefined) {
+      throw new InvariantError(`Message is undefined.`)
+    }
+
+    let newStopReason: Message['stop_reason'] | undefined
+
+    let delta: CompletionResponseChunk['choices'][0]['delta'] = {}
+    if (chunk.type === 'content_block_start') {
+      if (chunk.content_block.type === 'text') {
+        delta = {
+          content: chunk.content_block.text,
+        }
+      } else {
+        if (initialToolCallIndex === null) {
+          initialToolCallIndex = chunk.index
+        }
+
+        delta = {
+          tool_calls: [
+            {
+              index: chunk.index - initialToolCallIndex,
+              id: chunk.content_block?.id,
+              type: 'function',
+              function: {
+                name: chunk.content_block.name,
+                arguments: isEmptyObject(chunk.content_block.input)
+                  ? ''
+                  : JSON.stringify(chunk.content_block.input),
+              },
+            },
+          ],
+        }
+      }
+    } else if (chunk.type === 'content_block_delta') {
+      if (chunk.delta.type === 'input_json_delta') {
+        if (initialToolCallIndex === null) {
+          // We assign the initial tool call index in the `content_block_start` event, which should
+          // always come before a `content_block_delta` event, so this variable should never be null.
+          throw new InvariantError(
+            `Content block delta event came before a content block start event.`
+          )
+        }
+
+        delta = {
+          tool_calls: [
+            {
+              index: chunk.index - initialToolCallIndex,
+              function: {
+                arguments: chunk.delta.partial_json,
+              },
+            },
+          ],
+        }
+      } else {
+        delta = {
+          content: chunk.delta.text,
+        }
+      }
+    } else if (chunk.type === 'message_delta') {
+      newStopReason = chunk.delta.stop_reason
+      outputTokens = chunk.usage.output_tokens
+    }
+
+    const stopReason =
+      newStopReason !== undefined ? newStopReason : message.stop_reason
+    const finishReason = toFinishReasonStreaming(stopReason)
+
+    const chunkChoice = {
+      index: 0,
+      finish_reason: finishReason,
+      logprobs: null,
+      delta,
+    }
+
+    yield {
+      choices: [chunkChoice],
+      created,
+      model: message.model,
+      id: message.id,
+      object: 'chat.completion.chunk',
+      ...(chunk.type === 'message_stop' && {
+        usage: {
+          prompt_tokens: inputTokens,
+          completion_tokens: outputTokens,
+          total_tokens: inputTokens + outputTokens,
+        },
+      }),
+    }
+  }
+}
+
+const isTextBlock = (contentBlock: ContentBlock): contentBlock is TextBlock => {
+  return contentBlock.type === 'text'
+}
+
+const isToolUseBlock = (
+  contentBlock: ContentBlock
+): contentBlock is ToolUseBlock => {
+  return contentBlock.type === 'tool_use'
+}
+
+const toChatCompletionChoiceMessage = (
+  content: Message['content'],
+  role: Message['role'],
+  toolChoice: CompletionParams['tool_choice']
+): CompletionResponse['choices'][0]['message'] => {
+  const textBlocks = content.filter(isTextBlock)
+  if (textBlocks.length > 1) {
+    consoleWarn(
+      `Received multiple text blocks from Anthropic, which is unexpected. Concatenating the text blocks into a single string.`
+    )
+  }
+
+  let toolUseBlocks: ToolUseBlock[]
+  if (typeof toolChoice !== 'string' && toolChoice?.type === 'function') {
+    // When the user-defined tool_choice type is 'function', OpenAI always returns a single tool use
+    // block, but Anthropic can return multiple tool use blocks. We select just one of these blocks
+    // to conform to OpenAI's API.
+    const selected = content
+      .filter(isToolUseBlock)
+      .find((block) => block.name === toolChoice.function.name)
+    if (!selected) {
+      throw new InvariantError(
+        `Did not receive a tool use block from Anthropic for the function: ${toolChoice.function.name}`
+      )
+    }
+    toolUseBlocks = [selected]
+  } else {
+    toolUseBlocks = content.filter(isToolUseBlock)
+  }
+
+  let toolCalls: Array<ChatCompletionMessageToolCall> | undefined
+  if (toolUseBlocks.length > 0) {
+    toolCalls = toolUseBlocks.map((toolUse) => {
+      return {
+        id: toolUse.id,
+        function: {
+          name: toolUse.name,
+          arguments: JSON.stringify(toolUse.input),
+        },
+        type: 'function',
+      }
+    })
+  }
+
+  if (textBlocks.length === 0) {
+    // There can be zero text blocks if either of these scenarios happen:
+    // - A stop sequence is immediately hit, in which case Anthropic's `content` array is empty. In
+    //   this scenario, OpenAI returns an empty string `content` field.
+    // - There's only tool call responses. In this scenario, OpenAI returns a `content` field of
+    //   `null`.
+    const messageContent = content.every(isToolUseBlock) ? null : ''
+    return {
+      role,
+      refusal: null,
+      content: messageContent,
+      tool_calls: toolCalls,
+    }
+  } else {
+    return {
+      role,
+      refusal: null,
+      content: textBlocks.map((textBlock) => textBlock.text).join('\n'),
+      tool_calls: toolCalls,
+    }
+  }
+}
+
+const toFinishReasonNonStreaming = (
+  stopReason: Message['stop_reason']
+): CompletionResponse['choices'][0]['finish_reason'] => {
+  if (stopReason === null) {
+    // Anthropic's documentation says that the `stop_reason` will never be `null` for non-streaming
+    // calls.
+    throw new InvariantError(
+      `Detected a 'stop_reason' value of 'null' during a non-streaming call.`
+    )
+  }
+
+  if (stopReason === 'end_turn' || stopReason === 'stop_sequence') {
+    return 'stop'
+  } else if (stopReason === 'max_tokens') {
+    return 'length'
+  } else if (stopReason === 'tool_use') {
+    return 'tool_calls'
+  } else {
+    return 'unknown'
+  }
+}
+
+export const convertToolParams = (
+  toolChoice: CompletionParams['tool_choice'],
+  tools: CompletionParams['tools']
+): {
+  toolChoice: MessageCreateParamsNonStreaming['tool_choice']
+  tools: MessageCreateParamsNonStreaming['tools']
+} => {
+  if (tools === undefined || toolChoice === 'none') {
+    return { toolChoice: undefined, tools: undefined }
+  }
+
+  const convertedTools: MessageCreateParamsNonStreaming['tools'] = tools.map(
+    (tool) => {
+      return {
+        name: tool.function.name,
+        description: tool.function.description,
+        input_schema: { type: 'object', ...tool.function.parameters },
+      }
+    }
+  )
+
+  let convertedToolChoice: MessageCreateParamsNonStreaming['tool_choice']
+  if (toolChoice === undefined || toolChoice === 'auto') {
+    convertedToolChoice = { type: 'auto' }
+  } else if (toolChoice === 'required') {
+    convertedToolChoice = { type: 'any' }
+  } else {
+    convertedToolChoice = { type: 'tool', name: toolChoice.function.name }
+  }
+
+  return { toolChoice: convertedToolChoice, tools: convertedTools }
+}
+
+const toFinishReasonStreaming = (
+  stopReason: Message['stop_reason']
+): CompletionResponseChunk['choices'][0]['finish_reason'] => {
+  if (stopReason === null) {
+    return null
+  } else if (stopReason === 'end_turn' || stopReason === 'stop_sequence') {
+    return 'stop'
+  } else if (stopReason === 'max_tokens') {
+    return 'length'
+  } else if (stopReason === 'tool_use') {
+    return 'tool_calls'
+  } else {
+    return 'unknown'
+  }
+}
+
+export const getDefaultMaxTokens = (model: string): number => {
+  if (
+    model === 'claude-3-5-sonnet-v2@20241022' ||
+    model === 'claude-3-7-sonnet@20250219' ||
+    model === 'claude-sonnet-4@20250514'
+  ) {
+    return 8192
+  } else {
+    // We default to 8192 when the model is not specifically handled here to avoid throwing errors.
+    return 8192
+  }
+}
+
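+// For illustration (an assumed example of the conversion below): the
+// OpenAI-style input
+//   [{ role: 'system', content: 'Be terse.' }, { role: 'user', content: 'Hi' }]
+// becomes Anthropic's `system: 'Be terse.'` plus
+//   messages: [{ role: 'user', content: [{ type: 'text', text: 'Hi' }] }]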
+export const convertMessages = async (
+  messages: CompletionParams['messages']
+): Promise<{
+  messages: MessageCreateParamsNonStreaming['messages']
+  systemMessage: string | undefined
+}> => {
+  const output: MessageCreateParamsNonStreaming['messages'] = []
+  const clonedMessages = structuredClone(messages)
+
+  // Pop the first element from the user-defined `messages` array if it begins with a 'system'
+  // message. The returned element will be used for Anthropic's `system` parameter. We only pop the
+  // system message if it's the first element in the array so that the order of the messages remains
+  // unchanged.
+  let systemMessage: string | undefined
+  if (clonedMessages.length > 0 && clonedMessages[0].role === 'system') {
+    systemMessage = convertMessageContentToString(clonedMessages[0].content)
+    clonedMessages.shift()
+  }
+
+  // Anthropic requires that the first message in the array is from a 'user' role, so we inject a
+  // placeholder user message if the array doesn't already begin with a message from a 'user' role.
+  // (The length check guards against an input that contained only the popped system message.)
+  if (
+    clonedMessages.length === 0 ||
+    (clonedMessages[0].role !== 'user' && clonedMessages[0].role !== 'system')
+  ) {
+    clonedMessages.unshift({
+      role: 'user',
+      content: 'Empty',
+    })
+  }
+
+  let previousRole: 'user' | 'assistant' = 'user'
+  let currentParams: Array<
+    TextBlockParam | ImageBlockParam | ToolUseBlockParam | ToolResultBlockParam
+  > = []
+  for (const message of clonedMessages) {
+    // Anthropic doesn't support the `system` role in their `messages` array, so if the user
+    // defines system messages, we replace it with the `user` role and prepend 'System: ' to its
+    // content. We do this instead of putting every system message in Anthropic's `system` string
+    // parameter so that the order of the user-defined `messages` remains the same, even when the
+    // system messages are interspersed with messages from other roles.
+    const newRole =
+      message.role === 'user' ||
+      message.role === 'system' ||
+      message.role === 'tool'
+        ? 'user'
+        : 'assistant'
+
+    if (previousRole !== newRole) {
+      output.push({
+        role: previousRole,
+        content: currentParams,
+      })
+      currentParams = []
+    }
+
+    if (message.role === 'tool') {
+      const toolResult: ToolResultBlockParam = {
+        tool_use_id: message.tool_call_id,
+        content: message.content,
+        type: 'tool_result',
+      }
+      currentParams.push(toolResult)
+    } else if (message.role === 'assistant') {
+      if (typeof message.content === 'string') {
+        currentParams.push({
+          text: message.content,
+          type: 'text',
+        })
+      }
+
+      if (Array.isArray(message.tool_calls)) {
+        const convertedContent: Array<ToolUseBlockParam> =
+          message.tool_calls.map((toolCall) => {
+            return {
+              id: toolCall.id,
+              input: JSON.parse(toolCall.function.arguments),
+              name: toolCall.function.name,
+              type: 'tool_use',
+            }
+          })
+        currentParams.push(...convertedContent)
+      }
+    } else if (typeof message.content === 'string') {
+      const text =
+        message.role === 'system'
+          ? `System: ${message.content}`
+          : message.content
+      currentParams.push({
+        type: 'text',
+        text,
+      })
+    } else if (Array.isArray(message.content)) {
+      const convertedContent: Array<TextBlockParam | ImageBlockParam> =
+        await Promise.all(
+          message.content.map(async (e) => {
+            if (e.type === 'text') {
+              const text =
+                message.role === 'system' ? `System: ${e.text}` : e.text
+              return {
+                type: 'text',
+                text,
+              } as TextBlockParam
+            } else {
+              const parsedImage = await fetchThenParseImage(e.image_url.url)
+              return {
+                type: 'image',
+                source: {
+                  data: parsedImage.content,
+                  media_type: parsedImage.mimeType,
+                  type: 'base64',
+                },
+              } as ImageBlockParam
+            }
+          })
+        )
+      currentParams.push(...convertedContent)
+    }
+
+    previousRole = newRole
+  }
+
+  if (currentParams.length > 0) {
+    output.push({
+      role: previousRole,
+      content: currentParams,
+    })
+  }
+
+  return { messages: output, systemMessage }
+}
+
+export const convertStopSequences = (
+  stop?: CompletionParams['stop']
+): Array<string> | undefined => {
+  if (stop === null || stop === undefined) {
+    return undefined
+  } else if (typeof stop === 'string') {
+    return [stop]
+  } else if (Array.isArray(stop) && stop.every((e) => typeof e === 'string')) {
+    return stop
+  } else {
+    throw new Error(`Unknown stop sequence: ${stop}`)
+  }
+}
+
+const getCredentials = (
+  serviceAccount?: string
+): CredentialBody | Record<string, never> => {
+  try {
+    const base64EncodedCredentials =
+      serviceAccount ?? process.env.VERTEX_SERVICE_ACCOUNT_B64
+
+    if (!base64EncodedCredentials) {
+      return {} as CredentialBody
+    }
+    return JSON.parse(
+      Buffer.from(base64EncodedCredentials, 'base64').toString('ascii')
+    ) as CredentialBody
+  } catch (e) {
+    return {}
+  }
+}
+
+const getRegion = (region?: string) => {
+  return region ?? 'europe-west1'
+}
+
+const getProjectId = (projectId?: string) => {
+  return projectId ?? process.env.GOOGLE_VERTEX_PROJECT_ID
+}
+
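+// A minimal sketch of producing the base64 credentials consumed by
+// getCredentials above (the key file name is an assumption):
+//   export VERTEX_SERVICE_ACCOUNT_B64="$(base64 < service-account.json)"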
+export class AnthropicVertexHandler extends BaseHandler<AnthropicVertexModel> {
+  validateInputs(body: ProviderCompletionParams<'anthropic-vertex'>): void {
+    super.validateInputs(body)
+
+    let logImageDetailWarning: boolean = false
+    for (const message of body.messages) {
+      if (Array.isArray(message.content)) {
+        for (const e of message.content) {
+          if (e.type === 'image_url') {
+            if (
+              e.image_url.detail !== undefined &&
+              e.image_url.detail !== 'auto'
+            ) {
+              logImageDetailWarning = true
+            }
+          }
+        }
+      }
+    }
+
+    if (logImageDetailWarning) {
+      consoleWarn(
+        `Anthropic does not support the 'detail' field for images. The default image quality will be used.`
+      )
+    }
+  }
+
+  async create(
+    body: ProviderCompletionParams<'anthropic-vertex'>
+  ): Promise<CompletionResponse | StreamCompletionResponse> {
+    this.validateInputs(body)
+
+    const credentials = getCredentials(this.opts.vertex?.serviceAccount)
+    if (
+      !credentials ||
+      Object.keys(credentials).length === 0 ||
+      !credentials.client_email ||
+      !credentials.private_key
+    ) {
+      throw new InputError(
+        "No valid Vertex AI service account credentials detected. Please define a 'VERTEX_SERVICE_ACCOUNT_B64' environment variable or supply the credentials using the 'vertex.serviceAccount' parameter."
+      )
+    }
+
+    const googleAuth = new GoogleAuth({
+      scopes: 'https://www.googleapis.com/auth/cloud-platform',
+      credentials,
+    })
+
+    const client = new AnthropicVertex({
+      region: getRegion(this.opts.vertex?.region),
+      projectId: getProjectId(this.opts.vertex?.projectId),
+      googleAuth,
+    })
+
+    const stream = typeof body.stream === 'boolean' ? body.stream : undefined
+    const maxTokens = body.max_tokens ?? getDefaultMaxTokens(body.model)
+
+    const stopSequences = convertStopSequences(body.stop)
+    const topP = typeof body.top_p === 'number' ? body.top_p : undefined
+    const temperature =
+      typeof body.temperature === 'number'
+        ? // We divide by two because Anthropic's temperature range is 0 to 1, unlike OpenAI's,
+          // which is 0 to 2.
+          body.temperature / 2
+        : undefined
+    const { messages, systemMessage } = await convertMessages(body.messages)
+    const { toolChoice, tools } = convertToolParams(
+      body.tool_choice,
+      body.tools
+    )
+
+    if (stream === true) {
+      const convertedBody: MessageCreateParamsStreaming = {
+        max_tokens: maxTokens,
+        messages,
+        model: body.model,
+        stop_sequences: stopSequences,
+        temperature,
+        top_p: topP,
+        stream,
+        system: systemMessage,
+        tools,
+        tool_choice: toolChoice,
+      }
+      const created = getTimestamp()
+      const response = client.messages.stream(convertedBody)
+
+      return createCompletionResponseStreaming(response, created)
+    } else {
+      const convertedBody: MessageCreateParamsNonStreaming = {
+        max_tokens: maxTokens,
+        messages,
+        model: body.model,
+        stop_sequences: stopSequences,
+        temperature,
+        top_p: topP,
+        system: systemMessage,
+        tools,
+        tool_choice: toolChoice,
+      }
+
+      const created = getTimestamp()
+      const response = await client.messages.create(convertedBody)
+      return createCompletionResponseNonStreaming(
+        response,
+        created,
+        body.tool_choice
+      )
+    }
+  }
+}
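For reviewers, a minimal end-to-end sketch of the new provider. The `vertex` config shape, the `'anthropic-vertex'` provider key, and the model ID come straight from this PR; `Client` and the OpenAI-style `chat.completions.create` entry point are stand-ins for the package's actual exports, which this diff does not show:

```ts
// `Client` is a hypothetical name; substitute the package's real export.
const client = new Client({
  vertex: {
    projectId: 'my-gcp-project', // assumption: your GCP project ID
    region: 'europe-west1', // matches getRegion()'s default
    serviceAccount: process.env.VERTEX_SERVICE_ACCOUNT_B64,
  },
})

const completion = await client.chat.completions.create({
  provider: 'anthropic-vertex',
  model: 'claude-3-5-sonnet-v2@20241022',
  messages: [{ role: 'user', content: 'Hello from Vertex AI!' }],
})
```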
diff --git a/src/handlers/bedrock.ts b/src/handlers/bedrock.ts
index fc7db54..5e07132 100644
--- a/src/handlers/bedrock.ts
+++ b/src/handlers/bedrock.ts
@@ -523,7 +523,8 @@ async function* createCompletionResponseStreaming(
             {
               index: index - initialToolCallIndex,
               function: {
-                arguments: stream.contentBlockDelta.delta.toolUse.input || '{}',
+                arguments:
+                  stream.contentBlockDelta.delta.toolUse.input || '{}',
               },
             },
           ],
@@ -555,16 +556,19 @@ async function* createCompletionResponseStreaming(
             index: 0,
             finish_reason: finishReason,
             logprobs: null,
-            delta: finishReason === 'tool_calls' ? {
-              tool_calls: [
-                {
-                  index: 0,
-                  function: {
-                    arguments: '{}',
-                  },
-                },
-              ],
-            } : delta,
+            delta:
+              finishReason === 'tool_calls'
+                ? {
+                    tool_calls: [
+                      {
+                        index: 0,
+                        function: {
+                          arguments: '{}',
+                        },
+                      },
+                    ],
+                  }
+                : delta,
           },
         ],
         created,
diff --git a/src/handlers/utils.ts b/src/handlers/utils.ts
index 162afd1..cded724 100644
--- a/src/handlers/utils.ts
+++ b/src/handlers/utils.ts
@@ -6,6 +6,7 @@ import { LLMChatModel, LLMProvider } from '../chat/index.js'
 import { models } from '../models.js'
 import { ConfigOptions } from '../userTypes/index.js'
 import { AI21Handler } from './ai21.js'
+import { AnthropicVertexHandler } from './anthropic-vertex.js'
 import { AnthropicHandler } from './anthropic.js'
 import { BaseHandler } from './base.js'
 import { BedrockHandler } from './bedrock.js'
@@ -40,6 +41,16 @@ export const Handlers: Record<string, (opts: ConfigOptions) => any> = {
     models.anthropic.supportsN,
     models.anthropic.supportsStreaming
   ),
+  ['anthropic-vertex']: (opts: ConfigOptions) =>
+    new AnthropicVertexHandler(
+      opts,
+      models['anthropic-vertex'].models,
+      models['anthropic-vertex'].supportsJSON,
+      models['anthropic-vertex'].supportsImages,
+      models['anthropic-vertex'].supportsToolCalls,
+      models['anthropic-vertex'].supportsN,
+      models['anthropic-vertex'].supportsStreaming
+    ),
   ['gemini']: (opts: ConfigOptions) =>
     new GeminiHandler(
       opts,
diff --git a/src/models.ts b/src/models.ts
index 77df1c9..61fd12f 100644
--- a/src/models.ts
+++ b/src/models.ts
@@ -197,6 +197,33 @@
     supportsN: false,
     generateDocs: true,
   },
+  'anthropic-vertex': {
+    /* https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude */
+    models: [
+      'claude-3-5-sonnet-v2@20241022',
+      'claude-3-7-sonnet@20250219',
+      'claude-sonnet-4@20250514',
+    ] as const,
+    supportsCompletion: true,
+    supportsStreaming: [
+      'claude-3-5-sonnet-v2@20241022',
+      'claude-3-7-sonnet@20250219',
+      'claude-sonnet-4@20250514',
+    ] as const,
+    supportsJSON: [] as const,
+    supportsImages: [
+      'claude-3-5-sonnet-v2@20241022',
+      'claude-3-7-sonnet@20250219',
+      'claude-sonnet-4@20250514',
+    ] as const,
+    supportsToolCalls: [
+      'claude-3-5-sonnet-v2@20241022',
+      'claude-3-7-sonnet@20250219',
+      'claude-sonnet-4@20250514',
+    ] as const,
+    supportsN: false,
+    generateDocs: true,
+  },
   gemini: {
     models: [
       'gemini-2.0-flash-001',
diff --git a/src/userTypes/index.ts b/src/userTypes/index.ts
index 470d64f..bde57c0 100644
--- a/src/userTypes/index.ts
+++ b/src/userTypes/index.ts
@@ -16,6 +16,16 @@ export type ConfigOptions = Pick & {
     accessKeyId?: string
     secretAccessKey?: string
   }
+  /* @param vertex - The Vertex configuration object containing necessary credentials and settings
+   * @param vertex.region - The Google Cloud region where the Vertex AI service is deployed
+   * @param vertex.projectId - The Google Cloud project ID
+   * @param vertex.serviceAccount - Base64-encoded service account credentials for Google Cloud authentication
+   */
+  vertex?: {
+    region?: string
+    projectId?: string
+    serviceAccount?: string
+  }
 }

 export type ChatCompletionChoice = Omit<