Skip to content
Open
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
289 changes: 181 additions & 108 deletions src/providers/cohere/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,61 +13,53 @@ import { CohereStreamState } from './types';
export const CohereChatCompleteConfig: ProviderConfig = {
model: {
param: 'model',
default: 'command',
default: 'command-r-plus',
required: true,
},
messages: [
{
param: 'message',
required: true,
transform: (params: Params) => {
const messages = params.messages || [];
const prompt = messages.at(-1);
if (!prompt) {
throw new Error('messages length should be at least of length 1');
}
messages: {
param: 'messages',
required: true,
transform: (params: Params) => {
const messages = params.messages || [];
if (messages.length === 0) {
throw new Error('messages length should be at least of length 1');
}

return messages.map((message: Message) => {
let content: string = '';

if (typeof message.content === 'string') {
content = message.content;
} else if (Array.isArray(message.content)) {
const textContents = message.content
.filter((c) => c.type === 'text')
.map((c) => c.text);

if (textContents.length === 0) {
throw new Error('No text content found in message content array');
}

if (typeof prompt.content === 'string') {
return prompt.content;
content =
textContents.join(
'\
'
);
}

return prompt.content
?.filter((_msg) => _msg.type === 'text')
.reduce((acc, _msg) => acc + _msg.text + '\n', '');
},
},
{
param: 'chat_history',
required: false,
transform: (params: Params) => {
const messages = params.messages || [];
const messagesWithoutLastMessage = messages.slice(
0,
messages.length - 1
);
// generate history and forward it to model
const history: { message?: string; role: string }[] =
messagesWithoutLastMessage.map((message) => {
const _message: { role: any; message: string } = {
role: message.role === 'assistant' ? 'chatbot' : message.role,
message: '',
};

if (typeof message.content === 'string') {
_message['message'] = message.content;
} else if (Array.isArray(message.content)) {
_message['message'] = (message.content ?? [])
.filter((c) => Boolean(c.text))
.map((content) => content.text)
.join('\n');
}

return _message;
});
return history;
},
return {
role:
message.role === 'assistant'
? 'assistant'
: message.role === 'user'
? 'user'
: message.role === 'system'
? 'system'
: message.role,
content: content,
};
});
},
],
},
max_tokens: {
param: 'max_tokens',
default: 20,
Expand Down Expand Up @@ -108,50 +100,62 @@ export const CohereChatCompleteConfig: ProviderConfig = {
max: 1,
},
stop: {
param: 'end_sequences',
param: 'stop_sequences',
},
stream: {
param: 'stream',
default: false,
},
seed: {
param: 'seed',
},
logprobs: {
param: 'logprobs',
default: false,
},
};

interface CohereCompleteResponse {
text: string;
generation_id: string;
interface CohereV2CompleteResponse {
id: string;
finish_reason:
| 'COMPLETE'
| 'STOP_SEQUENCE'
| 'ERROR'
| 'ERROR_TOXIC'
| 'ERROR_LIMIT'
| 'USER_CANCEL'
| 'MAX_TOKENS';
meta: {
api_version: {
version: string;
};
billed_units: {
input_tokens: number;
output_tokens: number;
};
| 'MAX_TOKENS'
| 'TOOL_CALL'
| 'ERROR';
message: {
role: 'assistant';
content: Array<{
type: 'text';
text: string;
}>;
};
chat_history?: {
role: 'CHATBOT' | 'SYSTEM' | 'TOOL' | 'USER';
message: string;
}[];
usage: {
input_tokens: number;
output_tokens: number;
};
logprobs?: Array<{
token: string;
logprob: number;
}> | null;
}

/**
 * Loosely-typed body of a failed Cohere v2 chat request.
 *
 * Every field is optional, so a consumer has to supply its own fallback
 * when none of them is present.
 */
interface CohereV2ErrorResponse {
  /** Optional numeric status; presumably an HTTP-style code — confirm against the API. */
  status?: number;
  /** Optional error text carried under the `error` key. */
  error?: string;
  /** Optional error text carried under the `message` key. */
  message?: string;
}

export const CohereChatCompleteResponseTransform: (
response: CohereCompleteResponse,
response: CohereV2CompleteResponse | CohereV2ErrorResponse,
responseStatus: number
) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => {
if (responseStatus !== 200) {
const errorResponse = response as CohereV2ErrorResponse;
return generateErrorResponse(
{
message: response.message || '',
message:
errorResponse.message || errorResponse.error || 'Unknown error',
type: null,
param: null,
code: null,
Expand All @@ -160,39 +164,70 @@ export const CohereChatCompleteResponseTransform: (
);
}

const successResponse = response as CohereV2CompleteResponse;

const textContent = successResponse.message.content
.filter((c) => c.type === 'text')
.map((c) => c.text)
.join('');

return {
id: response.generation_id,
id: successResponse.id,
object: 'chat.completion',
created: Math.floor(Date.now() / 1000),
model: 'Unknown',
provider: COHERE,
choices: [
{
message: { role: 'assistant', content: response.text },
message: { role: 'assistant', content: textContent },
index: 0,
finish_reason: response.finish_reason,
finish_reason: successResponse.finish_reason,
},
],
usage: {
completion_tokens: response.meta.billed_units.output_tokens,
prompt_tokens: response.meta.billed_units.input_tokens,
completion_tokens: successResponse.usage.output_tokens,
prompt_tokens: successResponse.usage.input_tokens,
total_tokens: Number(
response.meta.billed_units.output_tokens +
response.meta.billed_units.input_tokens
successResponse.usage.output_tokens + successResponse.usage.input_tokens
),
},
};
};

export type CohereStreamChunk =
| { event_type: 'stream-start'; generation_id: string }
| { event_type: 'text-generation'; text: string }
export type CohereV2StreamChunk =
| {
type: 'message-start';
message: {
id: string;
role: 'assistant';
content: Array<any>;
};
}
| {
type: 'content-start';
index: number;
content_block: {
type: 'text';
text: string;
};
}
| {
event_type: 'stream-end';
response_id: string;
response: {
finish_reason: CohereCompleteResponse['finish_reason'];
meta: CohereCompleteResponse['meta'];
type: 'content-delta';
index: number;
delta: {
text: string;
};
}
| {
type: 'content-end';
index: number;
}
| {
type: 'message-end';
message: {
id: string;
finish_reason: CohereV2CompleteResponse['finish_reason'];
usage: CohereV2CompleteResponse['usage'];
};
};

Expand All @@ -212,40 +247,78 @@ export const CohereChatCompleteStreamChunkTransform: (
let chunk = responseChunk.trim();
chunk = chunk.replace(/^data: /, '');
chunk = chunk.trim();
const parsedChunk: CohereStreamChunk = JSON.parse(chunk);
if (parsedChunk.event_type === 'stream-start') {
streamState.generation_id = parsedChunk.generation_id;

if (!chunk || chunk === '[DONE]') {
return `data: [DONE]\n\n`;
}

return (
`data: ${JSON.stringify({
id: streamState?.generation_id ?? fallbackId,
try {
const parsedChunk: CohereV2StreamChunk = JSON.parse(chunk);

if (parsedChunk.type === 'message-start') {
streamState.generation_id = parsedChunk.message.id;
}

const messageId = streamState?.generation_id ?? fallbackId;
let deltaContent = '';
let finishReason = null;
let usage = null;

if (parsedChunk.type === 'content-delta') {
deltaContent = parsedChunk.delta.text;
} else if (parsedChunk.type === 'message-end') {
finishReason = parsedChunk.message.finish_reason;
usage = {
completion_tokens: parsedChunk.message.usage.output_tokens,
prompt_tokens: parsedChunk.message.usage.input_tokens,
total_tokens: Number(
parsedChunk.message.usage.output_tokens +
parsedChunk.message.usage.input_tokens
),
};
}

return (
`data: ${JSON.stringify({
id: messageId,
object: 'chat.completion.chunk',
created: Math.floor(Date.now() / 1000),
model: gatewayRequest.model || '',
provider: COHERE,
...(usage && { usage }),
choices: [
{
index: 0,
delta: {
content: deltaContent,
role: 'assistant',
},
logprobs: null,
finish_reason: finishReason,
},
],
})}` + '\n\n'
);
} catch (error) {
console.error('Error processing Cohere stream chunk:', error);
return `data: ${JSON.stringify({
id: fallbackId,
object: 'chat.completion.chunk',
created: Math.floor(Date.now() / 1000),
model: gatewayRequest.model || '',
provider: COHERE,
...(parsedChunk.event_type === 'stream-end' && {
usage: {
completion_tokens:
parsedChunk.response.meta.billed_units.output_tokens,
prompt_tokens: parsedChunk.response.meta.billed_units.input_tokens,
total_tokens: Number(
parsedChunk.response.meta.billed_units.output_tokens +
parsedChunk.response.meta.billed_units.input_tokens
),
},
}),
error: error instanceof Error ? error.message : String(error),
choices: [
{
index: 0,
delta: {
content: (parsedChunk as any)?.text ?? '',
content: '',
role: 'assistant',
},
logprobs: null,
finish_reason: (parsedChunk as any).finish_reason ?? null,
finish_reason: null,
},
],
})}` + '\n\n'
);
})}\n\n`;
}
};