From 6b31a1b63ee1b78ab91349e1eaf031c3490623ad Mon Sep 17 00:00:00 2001 From: Ajay-Satish-01 Date: Tue, 8 Jul 2025 04:17:12 -0500 Subject: [PATCH 1/7] improvement: cohere v2 integration for chat, streaming --- src/providers/cohere/chatComplete.ts | 271 ++++++++++++++++----------- 1 file changed, 163 insertions(+), 108 deletions(-) diff --git a/src/providers/cohere/chatComplete.ts b/src/providers/cohere/chatComplete.ts index 7b41aad0d..19d51e438 100644 --- a/src/providers/cohere/chatComplete.ts +++ b/src/providers/cohere/chatComplete.ts @@ -13,61 +13,37 @@ import { CohereStreamState } from './types'; export const CohereChatCompleteConfig: ProviderConfig = { model: { param: 'model', - default: 'command', + default: 'command-r-plus', required: true, }, - messages: [ - { - param: 'message', - required: true, - transform: (params: Params) => { - const messages = params.messages || []; - const prompt = messages.at(-1); - if (!prompt) { - throw new Error('messages length should be at least of length 1'); - } + messages: { + param: 'messages', + required: true, + transform: (params: Params) => { + const messages = params.messages || []; + if (messages.length === 0) { + throw new Error('messages length should be at least of length 1'); + } + + return messages.map((message: Message) => { + let content: string = ''; - if (typeof prompt.content === 'string') { - return prompt.content; + if (typeof message.content === 'string') { + content = message.content; + } else if (Array.isArray(message.content)) { + content = message.content + .filter((c) => c.type === 'text') + .map((c) => c.text) + .join('\n'); } - return prompt.content - ?.filter((_msg) => _msg.type === 'text') - .reduce((acc, _msg) => acc + _msg.text + '\n', ''); - }, + return { + role: message.role === 'assistant' ? 'assistant' : message.role, + content: content, + }; + }); }, - { - param: 'chat_history', - required: false, - transform: (params: Params) => { - const messages = params.messages || []; - const messagesWithoutLastMessage = messages.slice( - 0, - messages.length - 1 - ); - // generate history and forward it to model - const history: { message?: string; role: string }[] = - messagesWithoutLastMessage.map((message) => { - const _message: { role: any; message: string } = { - role: message.role === 'assistant' ? 'chatbot' : message.role, - message: '', - }; - - if (typeof message.content === 'string') { - _message['message'] = message.content; - } else if (Array.isArray(message.content)) { - _message['message'] = (message.content ?? []) - .filter((c) => Boolean(c.text)) - .map((content) => content.text) - .join('\n'); - } - - return _message; - }); - return history; - }, - }, - ], + }, max_tokens: { param: 'max_tokens', default: 20, @@ -108,50 +84,62 @@ export const CohereChatCompleteConfig: ProviderConfig = { max: 1, }, stop: { - param: 'end_sequences', + param: 'stop_sequences', }, stream: { param: 'stream', default: false, }, + seed: { + param: 'seed', + }, + logprobs: { + param: 'logprobs', + default: false, + }, }; -interface CohereCompleteResponse { - text: string; - generation_id: string; +interface CohereV2CompleteResponse { + id: string; finish_reason: | 'COMPLETE' | 'STOP_SEQUENCE' - | 'ERROR' - | 'ERROR_TOXIC' - | 'ERROR_LIMIT' - | 'USER_CANCEL' - | 'MAX_TOKENS'; - meta: { - api_version: { - version: string; - }; - billed_units: { - input_tokens: number; - output_tokens: number; - }; + | 'MAX_TOKENS' + | 'TOOL_CALL' + | 'ERROR'; + message: { + role: 'assistant'; + content: Array<{ + type: 'text'; + text: string; + }>; + }; + usage: { + input_tokens: number; + output_tokens: number; }; - chat_history?: { - role: 'CHATBOT' | 'SYSTEM' | 'TOOL' | 'USER'; - message: string; - }[]; + logprobs?: Array<{ + token: string; + logprob: number; + }> | null; +} + +interface CohereV2ErrorResponse { message?: string; + error?: string; status?: number; } export const CohereChatCompleteResponseTransform: ( - response: CohereCompleteResponse, + response: CohereV2CompleteResponse | CohereV2ErrorResponse, responseStatus: number ) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => { if (responseStatus !== 200) { + const errorResponse = response as CohereV2ErrorResponse; return generateErrorResponse( { - message: response.message || '', + message: + errorResponse.message || errorResponse.error || 'Unknown error', type: null, param: null, code: null, @@ -160,39 +148,70 @@ export const CohereChatCompleteResponseTransform: ( ); } + const successResponse = response as CohereV2CompleteResponse; + + const textContent = successResponse.message.content + .filter((c: { type: string; text: string }) => c.type === 'text') + .map((c: { type: string; text: string }) => c.text) + .join(''); + return { - id: response.generation_id, + id: successResponse.id, object: 'chat.completion', created: Math.floor(Date.now() / 1000), model: 'Unknown', provider: COHERE, choices: [ { - message: { role: 'assistant', content: response.text }, + message: { role: 'assistant', content: textContent }, index: 0, - finish_reason: response.finish_reason, + finish_reason: successResponse.finish_reason, }, ], usage: { - completion_tokens: response.meta.billed_units.output_tokens, - prompt_tokens: response.meta.billed_units.input_tokens, + completion_tokens: successResponse.usage.output_tokens, + prompt_tokens: successResponse.usage.input_tokens, total_tokens: Number( - response.meta.billed_units.output_tokens + - response.meta.billed_units.input_tokens + successResponse.usage.output_tokens + successResponse.usage.input_tokens ), }, }; }; -export type CohereStreamChunk = - | { event_type: 'stream-start'; generation_id: string } - | { event_type: 'text-generation'; text: string } +export type CohereV2StreamChunk = | { - event_type: 'stream-end'; - response_id: string; - response: { - finish_reason: CohereCompleteResponse['finish_reason']; - meta: CohereCompleteResponse['meta']; + type: 'message-start'; + message: { + id: string; + role: 'assistant'; + content: Array; + }; + } + | { + type: 'content-start'; + index: number; + content_block: { + type: 'text'; + text: string; + }; + } + | { + type: 'content-delta'; + index: number; + delta: { + text: string; + }; + } + | { + type: 'content-end'; + index: number; + } + | { + type: 'message-end'; + message: { + id: string; + finish_reason: CohereV2CompleteResponse['finish_reason']; + usage: CohereV2CompleteResponse['usage']; }; }; @@ -212,40 +231,76 @@ export const CohereChatCompleteStreamChunkTransform: ( let chunk = responseChunk.trim(); chunk = chunk.replace(/^data: /, ''); chunk = chunk.trim(); - const parsedChunk: CohereStreamChunk = JSON.parse(chunk); - if (parsedChunk.event_type === 'stream-start') { - streamState.generation_id = parsedChunk.generation_id; + + if (!chunk || chunk === '[DONE]') { + return `data: [DONE]\n\n`; } - return ( - `data: ${JSON.stringify({ - id: streamState?.generation_id ?? fallbackId, + try { + const parsedChunk: CohereV2StreamChunk = JSON.parse(chunk); + + if (parsedChunk.type === 'message-start') { + streamState.generation_id = parsedChunk.message.id; + } + + const messageId = streamState?.generation_id ?? fallbackId; + let deltaContent = ''; + let finishReason = null; + let usage = null; + + if (parsedChunk.type === 'content-delta') { + deltaContent = parsedChunk.delta.text; + } else if (parsedChunk.type === 'message-end') { + finishReason = parsedChunk.message.finish_reason; + usage = { + completion_tokens: parsedChunk.message.usage.output_tokens, + prompt_tokens: parsedChunk.message.usage.input_tokens, + total_tokens: Number( + parsedChunk.message.usage.output_tokens + + parsedChunk.message.usage.input_tokens + ), + }; + } + + return ( + `data: ${JSON.stringify({ + id: messageId, + object: 'chat.completion.chunk', + created: Math.floor(Date.now() / 1000), + model: gatewayRequest.model || '', + provider: COHERE, + ...(usage && { usage }), + choices: [ + { + index: 0, + delta: { + content: deltaContent, + role: 'assistant', + }, + logprobs: null, + finish_reason: finishReason, + }, + ], + })}` + '\n\n' + ); + } catch (error) { + return `data: ${JSON.stringify({ + id: fallbackId, object: 'chat.completion.chunk', created: Math.floor(Date.now() / 1000), model: gatewayRequest.model || '', provider: COHERE, - ...(parsedChunk.event_type === 'stream-end' && { - usage: { - completion_tokens: - parsedChunk.response.meta.billed_units.output_tokens, - prompt_tokens: parsedChunk.response.meta.billed_units.input_tokens, - total_tokens: Number( - parsedChunk.response.meta.billed_units.output_tokens + - parsedChunk.response.meta.billed_units.input_tokens - ), - }, - }), choices: [ { index: 0, delta: { - content: (parsedChunk as any)?.text ?? '', + content: '', role: 'assistant', }, logprobs: null, - finish_reason: (parsedChunk as any).finish_reason ?? null, + finish_reason: null, }, ], - })}` + '\n\n' - ); + })}\n\n`; + } }; From ca9e24e7ce66714bd0c8452e693414f311479a66 Mon Sep 17 00:00:00 2001 From: Ajay Satish <71289526+Ajay-Satish-01@users.noreply.github.com> Date: Tue, 8 Jul 2025 04:26:37 -0500 Subject: [PATCH 2/7] chore: commit changes from matter-code Co-authored-by: matter-code-review[bot] <150888575+matter-code-review[bot]@users.noreply.github.com> --- src/providers/cohere/chatComplete.ts | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/providers/cohere/chatComplete.ts b/src/providers/cohere/chatComplete.ts index 19d51e438..cf4faeeae 100644 --- a/src/providers/cohere/chatComplete.ts +++ b/src/providers/cohere/chatComplete.ts @@ -31,14 +31,23 @@ export const CohereChatCompleteConfig: ProviderConfig = { if (typeof message.content === 'string') { content = message.content; } else if (Array.isArray(message.content)) { - content = message.content + const textContents = message.content .filter((c) => c.type === 'text') - .map((c) => c.text) - .join('\n'); + .map((c) => c.text); + + if (textContents.length === 0) { + throw new Error('No text content found in message content array'); + } + + content = textContents.join('\ +'); } return { - role: message.role === 'assistant' ? 'assistant' : message.role, + role: message.role === 'assistant' ? 'assistant' : + message.role === 'user' ? 'user' : + message.role === 'system' ? 'system' : + message.role, content: content, }; }); @@ -151,8 +160,8 @@ export const CohereChatCompleteResponseTransform: ( const successResponse = response as CohereV2CompleteResponse; const textContent = successResponse.message.content - .filter((c: { type: string; text: string }) => c.type === 'text') - .map((c: { type: string; text: string }) => c.text) + .filter((c) => c.type === 'text') + .map((c) => c.text) .join(''); return { @@ -284,12 +293,14 @@ export const CohereChatCompleteStreamChunkTransform: ( })}` + '\n\n' ); } catch (error) { + console.error('Error processing Cohere stream chunk:', error); return `data: ${JSON.stringify({ id: fallbackId, object: 'chat.completion.chunk', created: Math.floor(Date.now() / 1000), model: gatewayRequest.model || '', provider: COHERE, + error: error instanceof Error ? error.message : String(error), choices: [ { index: 0, From e1a97abb1609eb70321415a91e009a8a6fabdc0e Mon Sep 17 00:00:00 2001 From: Ajay-Satish-01 Date: Tue, 8 Jul 2025 04:31:46 -0500 Subject: [PATCH 3/7] chore: formatting fix --- plugins/aporia/validateProject.ts | 2 +- src/handlers/handlerUtils.ts | 2 +- src/providers/cohere/chatComplete.ts | 23 +++++++++++++++-------- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/plugins/aporia/validateProject.ts b/plugins/aporia/validateProject.ts index 65aef7936..edc93b81a 100644 --- a/plugins/aporia/validateProject.ts +++ b/plugins/aporia/validateProject.ts @@ -52,7 +52,7 @@ export const handler: PluginHandler = async ( } else { const _content = message.content?.reduce( (value, item) => - value + (item.type === 'text' ? `${item.text}\n` ?? '' : ''), + value + (item.type === 'text' ? (`${item.text}\n` ?? '') : ''), '' ); return { ...message, content: _content }; diff --git a/src/handlers/handlerUtils.ts b/src/handlers/handlerUtils.ts index 3029e748d..def278c71 100644 --- a/src/handlers/handlerUtils.ts +++ b/src/handlers/handlerUtils.ts @@ -461,7 +461,7 @@ export async function tryPost( providerOption.retry = { attempts: providerOption.retry?.attempts ?? 0, onStatusCodes: providerOption.retry?.attempts - ? providerOption.retry?.onStatusCodes ?? RETRY_STATUS_CODES + ? (providerOption.retry?.onStatusCodes ?? RETRY_STATUS_CODES) : [], useRetryAfterHeader: providerOption?.retry?.useRetryAfterHeader, }; diff --git a/src/providers/cohere/chatComplete.ts b/src/providers/cohere/chatComplete.ts index cf4faeeae..0ff3a20d0 100644 --- a/src/providers/cohere/chatComplete.ts +++ b/src/providers/cohere/chatComplete.ts @@ -34,20 +34,27 @@ export const CohereChatCompleteConfig: ProviderConfig = { const textContents = message.content .filter((c) => c.type === 'text') .map((c) => c.text); - + if (textContents.length === 0) { throw new Error('No text content found in message content array'); } - - content = textContents.join('\ -'); + + content = + textContents.join( + '\ +' + ); } return { - role: message.role === 'assistant' ? 'assistant' : - message.role === 'user' ? 'user' : - message.role === 'system' ? 'system' : - message.role, + role: + message.role === 'assistant' + ? 'assistant' + : message.role === 'user' + ? 'user' + : message.role === 'system' + ? 'system' + : message.role, content: content, }; }); From b16dc7a848c5c4845da78db909cb1df52feed84b Mon Sep 17 00:00:00 2001 From: Ajay-Satish-01 Date: Tue, 8 Jul 2025 04:35:09 -0500 Subject: [PATCH 4/7] chore: formtting fix --- plugins/aporia/validateProject.ts | 2 +- src/handlers/handlerUtils.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/aporia/validateProject.ts b/plugins/aporia/validateProject.ts index edc93b81a..65aef7936 100644 --- a/plugins/aporia/validateProject.ts +++ b/plugins/aporia/validateProject.ts @@ -52,7 +52,7 @@ export const handler: PluginHandler = async ( } else { const _content = message.content?.reduce( (value, item) => - value + (item.type === 'text' ? (`${item.text}\n` ?? '') : ''), + value + (item.type === 'text' ? `${item.text}\n` ?? '' : ''), '' ); return { ...message, content: _content }; diff --git a/src/handlers/handlerUtils.ts b/src/handlers/handlerUtils.ts index def278c71..3029e748d 100644 --- a/src/handlers/handlerUtils.ts +++ b/src/handlers/handlerUtils.ts @@ -461,7 +461,7 @@ export async function tryPost( providerOption.retry = { attempts: providerOption.retry?.attempts ?? 0, onStatusCodes: providerOption.retry?.attempts - ? (providerOption.retry?.onStatusCodes ?? RETRY_STATUS_CODES) + ? providerOption.retry?.onStatusCodes ?? RETRY_STATUS_CODES : [], useRetryAfterHeader: providerOption?.retry?.useRetryAfterHeader, }; From ed9c1e77d6abe5c2559f705a6ef52e0f8b52efcb Mon Sep 17 00:00:00 2001 From: Ajay-Satish-01 Date: Wed, 9 Jul 2025 14:49:52 -0500 Subject: [PATCH 5/7] fix(api): v2 version --- src/providers/cohere/api.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/providers/cohere/api.ts b/src/providers/cohere/api.ts index 57d7a1b22..7a84cc8c3 100644 --- a/src/providers/cohere/api.ts +++ b/src/providers/cohere/api.ts @@ -1,7 +1,7 @@ import { ProviderAPIConfig } from '../types'; const CohereAPIConfig: ProviderAPIConfig = { - getBaseURL: () => 'https://api.cohere.ai/v1', + getBaseURL: () => 'https://api.cohere.ai/v2', headers: ({ providerOptions, fn }) => { const headers: Record = { Authorization: `Bearer ${providerOptions.apiKey}`, From 1ef109e39d546e8b8a496b401efd20e7c6cca932 Mon Sep 17 00:00:00 2001 From: Ajay-Satish-01 Date: Wed, 9 Jul 2025 14:50:29 -0500 Subject: [PATCH 6/7] chore(cohere): error handling --- src/providers/cohere/chatComplete.ts | 96 ++++++++++++++++++++++------ 1 file changed, 78 insertions(+), 18 deletions(-) diff --git a/src/providers/cohere/chatComplete.ts b/src/providers/cohere/chatComplete.ts index 0ff3a20d0..5002cff1d 100644 --- a/src/providers/cohere/chatComplete.ts +++ b/src/providers/cohere/chatComplete.ts @@ -39,11 +39,7 @@ export const CohereChatCompleteConfig: ProviderConfig = { throw new Error('No text content found in message content array'); } - content = - textContents.join( - '\ -' - ); + content = textContents.join('\n'); } return { @@ -65,11 +61,6 @@ export const CohereChatCompleteConfig: ProviderConfig = { default: 20, min: 1, }, - max_completion_tokens: { - param: 'max_tokens', - default: 20, - min: 1, - }, temperature: { param: 'temperature', default: 0.75, @@ -113,6 +104,30 @@ export const CohereChatCompleteConfig: ProviderConfig = { param: 'logprobs', default: false, }, + preamble: { + param: 'preamble', + }, + connectors: { + param: 'connectors', + }, + search_queries_only: { + param: 'search_queries_only', + default: false, + }, + citation_quality: { + param: 'citation_quality', + default: 'accurate', + }, + prompt_truncation: { + param: 'prompt_truncation', + default: 'AUTO', + }, + tools: { + param: 'tools', + }, + tool_results: { + param: 'tool_results', + }, }; interface CohereV2CompleteResponse { @@ -126,8 +141,12 @@ interface CohereV2CompleteResponse { message: { role: 'assistant'; content: Array<{ - type: 'text'; - text: string; + type: 'text' | 'tool_calls'; + text?: string; + tool_calls?: Array<{ + name: string; + parameters: Record; + }>; }>; }; usage: { @@ -138,6 +157,21 @@ interface CohereV2CompleteResponse { token: string; logprob: number; }> | null; + search_results?: Array<{ + search_query: string; + results: Array<{ + id: string; + title: string; + url: string; + text: string; + }>; + }>; + citations?: Array<{ + start: number; + end: number; + text: string; + document_ids: string[]; + }>; } interface CohereV2ErrorResponse { @@ -171,6 +205,23 @@ export const CohereChatCompleteResponseTransform: ( .map((c) => c.text) .join(''); + const toolCalls = successResponse.message.content + .filter((c) => c.type === 'tool_calls') + .flatMap((c) => c.tool_calls || []); + + const message: any = { role: 'assistant', content: textContent }; + + if (toolCalls.length > 0) { + message.tool_calls = toolCalls.map((toolCall, index) => ({ + id: `call_${index}`, + type: 'function', + function: { + name: toolCall.name, + arguments: JSON.stringify(toolCall.parameters), + }, + })); + } + return { id: successResponse.id, object: 'chat.completion', @@ -179,7 +230,7 @@ export const CohereChatCompleteResponseTransform: ( provider: COHERE, choices: [ { - message: { role: 'assistant', content: textContent }, + message, index: 0, finish_reason: successResponse.finish_reason, }, @@ -207,15 +258,23 @@ export type CohereV2StreamChunk = type: 'content-start'; index: number; content_block: { - type: 'text'; - text: string; + type: 'text' | 'tool_calls'; + text?: string; + tool_calls?: Array<{ + name: string; + parameters: Record; + }>; }; } | { type: 'content-delta'; index: number; delta: { - text: string; + text?: string; + tool_calls?: Array<{ + name: string; + parameters: Record; + }>; }; } | { @@ -228,6 +287,8 @@ export type CohereV2StreamChunk = id: string; finish_reason: CohereV2CompleteResponse['finish_reason']; usage: CohereV2CompleteResponse['usage']; + search_results?: CohereV2CompleteResponse['search_results']; + citations?: CohereV2CompleteResponse['citations']; }; }; @@ -265,7 +326,7 @@ export const CohereChatCompleteStreamChunkTransform: ( let usage = null; if (parsedChunk.type === 'content-delta') { - deltaContent = parsedChunk.delta.text; + deltaContent = parsedChunk.delta.text || ''; } else if (parsedChunk.type === 'message-end') { finishReason = parsedChunk.message.finish_reason; usage = { @@ -300,7 +361,6 @@ export const CohereChatCompleteStreamChunkTransform: ( })}` + '\n\n' ); } catch (error) { - console.error('Error processing Cohere stream chunk:', error); return `data: ${JSON.stringify({ id: fallbackId, object: 'chat.completion.chunk', From 9c77bc26fd906f5bfc1bd1d2dc0a10b1d14b92bb Mon Sep 17 00:00:00 2001 From: Ajay-Satish-01 Date: Fri, 11 Jul 2025 22:35:45 -0500 Subject: [PATCH 7/7] feat(cohere): add more v2 chat completions --- src/providers/cohere/chatComplete.ts | 440 ++++++++++++++++++--------- src/providers/cohere/types.ts | 2 + 2 files changed, 303 insertions(+), 139 deletions(-) diff --git a/src/providers/cohere/chatComplete.ts b/src/providers/cohere/chatComplete.ts index 5002cff1d..ffed4b6e0 100644 --- a/src/providers/cohere/chatComplete.ts +++ b/src/providers/cohere/chatComplete.ts @@ -13,7 +13,6 @@ import { CohereStreamState } from './types'; export const CohereChatCompleteConfig: ProviderConfig = { model: { param: 'model', - default: 'command-r-plus', required: true, }, messages: { @@ -22,114 +21,214 @@ export const CohereChatCompleteConfig: ProviderConfig = { transform: (params: Params) => { const messages = params.messages || []; if (messages.length === 0) { - throw new Error('messages length should be at least of length 1'); + throw new Error('At least one message is required'); } return messages.map((message: Message) => { - let content: string = ''; + if (message.role === 'system') { + return { + role: 'system', + content: typeof message.content === 'string' ? message.content : '', + }; + } + + if (message.role === 'tool') { + return { + role: 'tool', + tool_call_id: message.tool_call_id, + content: message.content, + }; + } + + let content: string | Array = ''; if (typeof message.content === 'string') { content = message.content; } else if (Array.isArray(message.content)) { - const textContents = message.content - .filter((c) => c.type === 'text') - .map((c) => c.text); + const cohereContent: Array = []; - if (textContents.length === 0) { - throw new Error('No text content found in message content array'); + for (const item of message.content) { + if (item.type === 'text') { + cohereContent.push({ + type: 'text', + text: item.text, + }); + } else if (item.type === 'image_url') { + cohereContent.push({ + type: 'image', + source: { + type: 'url', + url: item.image_url?.url, + }, + }); + } } - content = textContents.join('\n'); + content = cohereContent.length > 0 ? cohereContent : ''; } - return { - role: - message.role === 'assistant' - ? 'assistant' - : message.role === 'user' - ? 'user' - : message.role === 'system' - ? 'system' - : message.role, + const cohereMessage: any = { + role: message.role === 'assistant' ? 'assistant' : 'user', content: content, }; + + if (message.role === 'assistant' && message.tool_calls) { + cohereMessage.tool_calls = message.tool_calls.map( + (toolCall: any) => ({ + id: toolCall.id, + type: toolCall.type, + function: { + name: toolCall.function.name, + arguments: toolCall.function.arguments, + }, + }) + ); + } + + if (message.role === 'assistant' && (message as any).tool_plan) { + cohereMessage.tool_plan = (message as any).tool_plan; + } + + return cohereMessage; }); }, }, max_tokens: { param: 'max_tokens', - default: 20, min: 1, }, temperature: { param: 'temperature', - default: 0.75, + default: 0.3, min: 0, - max: 5, }, - top_p: { - param: 'p', - default: 0.75, - min: 0, - max: 1, + seed: { + param: 'seed', + }, + stop: { + param: 'stop_sequences', }, top_k: { param: 'k', default: 0, + min: 0, max: 500, }, + top_p: { + param: 'p', + default: 0.75, + min: 0.01, + max: 0.99, + }, frequency_penalty: { param: 'frequency_penalty', - default: 0, - min: 0, - max: 1, + default: 0.0, + min: 0.0, + max: 1.0, }, presence_penalty: { param: 'presence_penalty', - default: 0, - min: 0, - max: 1, + default: 0.0, + min: 0.0, + max: 1.0, }, - stop: { - param: 'stop_sequences', + logprobs: { + param: 'logprobs', + default: false, }, stream: { param: 'stream', default: false, }, - seed: { - param: 'seed', - }, - logprobs: { - param: 'logprobs', - default: false, - }, - preamble: { - param: 'preamble', + tools: { + param: 'tools', + transform: (params: Params) => { + if (!params.tools) return undefined; + + return params.tools.map((tool) => ({ + type: 'function', + function: { + name: tool.function.name, + description: tool.function.description, + parameters: tool.function.parameters, + }, + })); + }, }, - connectors: { - param: 'connectors', + tool_choice: { + param: 'tool_choice', + transform: (params: Params) => { + const toolChoice = params.tool_choice; + if (!toolChoice) return undefined; + + if (toolChoice === 'none') return 'NONE'; + if (toolChoice === 'required') return 'REQUIRED'; + if (toolChoice === 'auto') return undefined; + + return toolChoice; + }, }, - search_queries_only: { - param: 'search_queries_only', - default: false, + strict_tools: { + param: 'strict_tools', }, - citation_quality: { - param: 'citation_quality', - default: 'accurate', + documents: { + param: 'documents', }, - prompt_truncation: { - param: 'prompt_truncation', - default: 'AUTO', + citation_options: { + param: 'citation_options', }, - tools: { - param: 'tools', + response_format: { + param: 'response_format', }, - tool_results: { - param: 'tool_results', + safety_mode: { + param: 'safety_mode', + default: 'CONTEXTUAL', }, }; +interface CohereV2Usage { + billed_units: { + input_tokens: number; + output_tokens: number; + }; + tokens: { + input_tokens: number; + output_tokens: number; + }; +} + +interface CohereV2Citation { + start: number; + end: number; + text: string; + sources: Array<{ + type: string; + id: string; + document?: any; + tool_output?: any; + }>; +} + +interface CohereV2ToolCall { + id: string; + type: 'function'; + function: { + name: string; + arguments: string; + }; +} + +interface CohereV2Message { + role: 'assistant'; + content: Array<{ + type: 'text'; + text: string; + }>; + tool_calls?: CohereV2ToolCall[]; + tool_plan?: string; + citations?: CohereV2Citation[]; +} + interface CohereV2CompleteResponse { id: string; finish_reason: @@ -138,46 +237,18 @@ interface CohereV2CompleteResponse { | 'MAX_TOKENS' | 'TOOL_CALL' | 'ERROR'; - message: { - role: 'assistant'; - content: Array<{ - type: 'text' | 'tool_calls'; - text?: string; - tool_calls?: Array<{ - name: string; - parameters: Record; - }>; - }>; - }; - usage: { - input_tokens: number; - output_tokens: number; - }; + message: CohereV2Message; + usage?: CohereV2Usage; logprobs?: Array<{ token: string; logprob: number; - }> | null; - search_results?: Array<{ - search_query: string; - results: Array<{ - id: string; - title: string; - url: string; - text: string; - }>; - }>; - citations?: Array<{ - start: number; - end: number; - text: string; - document_ids: string[]; }>; } interface CohereV2ErrorResponse { message?: string; error?: string; - status?: number; + detail?: string; } export const CohereChatCompleteResponseTransform: ( @@ -189,7 +260,10 @@ export const CohereChatCompleteResponseTransform: ( return generateErrorResponse( { message: - errorResponse.message || errorResponse.error || 'Unknown error', + errorResponse.message || + errorResponse.error || + errorResponse.detail || + 'Unknown error', type: null, param: null, code: null, @@ -205,43 +279,54 @@ export const CohereChatCompleteResponseTransform: ( .map((c) => c.text) .join(''); - const toolCalls = successResponse.message.content - .filter((c) => c.type === 'tool_calls') - .flatMap((c) => c.tool_calls || []); - - const message: any = { role: 'assistant', content: textContent }; + const message: any = { + role: 'assistant', + content: textContent, + }; - if (toolCalls.length > 0) { - message.tool_calls = toolCalls.map((toolCall, index) => ({ - id: `call_${index}`, - type: 'function', + if ( + successResponse.message.tool_calls && + successResponse.message.tool_calls.length > 0 + ) { + message.tool_calls = successResponse.message.tool_calls.map((toolCall) => ({ + id: toolCall.id, + type: toolCall.type, function: { - name: toolCall.name, - arguments: JSON.stringify(toolCall.parameters), + name: toolCall.function.name, + arguments: toolCall.function.arguments, }, })); } + let finishReason: string = successResponse.finish_reason; + if (finishReason === 'COMPLETE') finishReason = 'stop'; + else if (finishReason === 'MAX_TOKENS') finishReason = 'length'; + else if (finishReason === 'TOOL_CALL') finishReason = 'tool_calls'; + else if (finishReason === 'STOP_SEQUENCE') finishReason = 'stop'; + return { id: successResponse.id, object: 'chat.completion', created: Math.floor(Date.now() / 1000), - model: 'Unknown', + model: 'command-r-plus', // Default model name provider: COHERE, choices: [ { message, index: 0, - finish_reason: successResponse.finish_reason, + finish_reason: finishReason, }, ], - usage: { - completion_tokens: successResponse.usage.output_tokens, - prompt_tokens: successResponse.usage.input_tokens, - total_tokens: Number( - successResponse.usage.output_tokens + successResponse.usage.input_tokens - ), - }, + usage: successResponse.usage + ? { + completion_tokens: + successResponse.usage.billed_units?.output_tokens || 0, + prompt_tokens: successResponse.usage.billed_units?.input_tokens || 0, + total_tokens: + (successResponse.usage.billed_units?.output_tokens || 0) + + (successResponse.usage.billed_units?.input_tokens || 0), + } + : undefined, }; }; @@ -252,18 +337,16 @@ export type CohereV2StreamChunk = id: string; role: 'assistant'; content: Array; + tool_calls?: Array; + tool_plan?: string; }; } | { type: 'content-start'; index: number; content_block: { - type: 'text' | 'tool_calls'; - text?: string; - tool_calls?: Array<{ - name: string; - parameters: Record; - }>; + type: 'text'; + text: string; }; } | { @@ -271,24 +354,58 @@ export type CohereV2StreamChunk = index: number; delta: { text?: string; - tool_calls?: Array<{ - name: string; - parameters: Record; - }>; }; } | { type: 'content-end'; index: number; } + | { + type: 'tool-plan-delta'; + delta: { + tool_plan?: string; + }; + } + | { + type: 'tool-call-start'; + index: number; + tool_call: { + id: string; + type: 'function'; + function: { + name: string; + arguments: string; + }; + }; + } + | { + type: 'tool-call-delta'; + index: number; + delta: { + function?: { + arguments?: string; + }; + }; + } + | { + type: 'tool-call-end'; + index: number; + } + | { + type: 'citation-start'; + index: number; + citation: CohereV2Citation; + } + | { + type: 'citation-end'; + index: number; + } | { type: 'message-end'; message: { id: string; finish_reason: CohereV2CompleteResponse['finish_reason']; - usage: CohereV2CompleteResponse['usage']; - search_results?: CohereV2CompleteResponse['search_results']; - citations?: CohereV2CompleteResponse['citations']; + usage?: CohereV2Usage; }; }; @@ -301,7 +418,7 @@ export const CohereChatCompleteStreamChunkTransform: ( ) => string = ( responseChunk, fallbackId, - streamState = { generation_id: '' }, + streamState = { generation_id: '', tool_calls: {}, current_tool_call: null }, _strictOpenAiCompliance, gatewayRequest ) => { @@ -318,25 +435,73 @@ export const CohereChatCompleteStreamChunkTransform: ( if (parsedChunk.type === 'message-start') { streamState.generation_id = parsedChunk.message.id; + streamState.tool_calls = {}; + streamState.current_tool_call = null; } const messageId = streamState?.generation_id ?? fallbackId; let deltaContent = ''; let finishReason = null; let usage = null; + let toolCalls = null; if (parsedChunk.type === 'content-delta') { deltaContent = parsedChunk.delta.text || ''; - } else if (parsedChunk.type === 'message-end') { - finishReason = parsedChunk.message.finish_reason; - usage = { - completion_tokens: parsedChunk.message.usage.output_tokens, - prompt_tokens: parsedChunk.message.usage.input_tokens, - total_tokens: Number( - parsedChunk.message.usage.output_tokens + - parsedChunk.message.usage.input_tokens - ), + } else if (parsedChunk.type === 'tool-call-start') { + streamState.current_tool_call = { + id: parsedChunk.tool_call.id, + type: 'function', + function: { + name: parsedChunk.tool_call.function.name, + arguments: '', + }, }; + streamState.tool_calls[parsedChunk.index] = streamState.current_tool_call; + + toolCalls = [streamState.current_tool_call]; + } else if (parsedChunk.type === 'tool-call-delta') { + if ( + streamState.current_tool_call && + parsedChunk.delta.function?.arguments + ) { + streamState.current_tool_call.function.arguments += + parsedChunk.delta.function.arguments; + toolCalls = [streamState.current_tool_call]; + } + } else if (parsedChunk.type === 'message-end') { + const cohereFinishReason = parsedChunk.message.finish_reason; + let mappedFinishReason: string; + if (cohereFinishReason === 'COMPLETE') mappedFinishReason = 'stop'; + else if (cohereFinishReason === 'MAX_TOKENS') + mappedFinishReason = 'length'; + else if (cohereFinishReason === 'TOOL_CALL') + mappedFinishReason = 'tool_calls'; + else if (cohereFinishReason === 'STOP_SEQUENCE') + mappedFinishReason = 'stop'; + else mappedFinishReason = cohereFinishReason; + + finishReason = mappedFinishReason; + + if (parsedChunk.message.usage) { + usage = { + completion_tokens: + parsedChunk.message.usage.billed_units?.output_tokens || 0, + prompt_tokens: + parsedChunk.message.usage.billed_units?.input_tokens || 0, + total_tokens: + (parsedChunk.message.usage.billed_units?.output_tokens || 0) + + (parsedChunk.message.usage.billed_units?.input_tokens || 0), + }; + } + } + + const delta: any = { + content: deltaContent, + role: 'assistant', + }; + + if (toolCalls) { + delta.tool_calls = toolCalls; } return ( @@ -344,16 +509,13 @@ export const CohereChatCompleteStreamChunkTransform: ( id: messageId, object: 'chat.completion.chunk', created: Math.floor(Date.now() / 1000), - model: gatewayRequest.model || '', + model: gatewayRequest.model || 'command-r-plus', provider: COHERE, ...(usage && { usage }), choices: [ { index: 0, - delta: { - content: deltaContent, - role: 'assistant', - }, + delta, logprobs: null, finish_reason: finishReason, }, @@ -365,7 +527,7 @@ export const CohereChatCompleteStreamChunkTransform: ( id: fallbackId, object: 'chat.completion.chunk', created: Math.floor(Date.now() / 1000), - model: gatewayRequest.model || '', + model: gatewayRequest.model || 'command-r-plus', provider: COHERE, error: error instanceof Error ? error.message : String(error), choices: [ diff --git a/src/providers/cohere/types.ts b/src/providers/cohere/types.ts index 7d295aa57..e08f7b9e6 100644 --- a/src/providers/cohere/types.ts +++ b/src/providers/cohere/types.ts @@ -1,5 +1,7 @@ export type CohereStreamState = { generation_id: string; + tool_calls: Record; + current_tool_call: any; }; export interface CohereErrorResponse {