
Commit 7841989

Merge pull request #1314 from narengogi/feat/mistral-on-vertex

mistral on vertex

2 parents 15b0733 + fa30657

5 files changed (+115 lines, -99 lines)

src/providers/google-vertex-ai/api.ts (1 addition, 0 deletions)

@@ -169,6 +169,7 @@ export const GoogleApiConfig: ProviderAPIConfig = {
         return googleUrlMap.get(mappedFn) || `${projectRoute}`;
       }
 
+      case 'mistralai':
       case 'anthropic': {
         if (mappedFn === 'chatComplete' || mappedFn === 'messages') {
          return `${projectRoute}/publishers/${provider}/models/${model}:rawPredict`;
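
The new `case 'mistralai':` has no body, so it falls through to the Anthropic branch and Mistral chat completions reuse the same rawPredict route. A minimal sketch of the resulting URL, assuming a hypothetical project route and model id (both come from the actual deployment):

// Sketch only: projectRoute and model are made-up placeholders.
const projectRoute = 'projects/my-project/locations/us-central1'; // hypothetical
const provider = 'mistralai';
const model = 'mistral-large'; // hypothetical model id
// Same template literal as in the hunk above:
const url = `${projectRoute}/publishers/${provider}/models/${model}:rawPredict`;
console.log(url);
// projects/my-project/locations/us-central1/publishers/mistralai/models/mistral-large:rawPredict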

src/providers/google-vertex-ai/index.ts (16 additions, 0 deletions)

@@ -51,6 +51,11 @@ import {
   VertexAnthropicMessagesConfig,
   VertexAnthropicMessagesResponseTransform,
 } from './messages';
+import {
+  GetMistralAIChatCompleteResponseTransform,
+  GetMistralAIChatCompleteStreamChunkTransform,
+  MistralAIChatCompleteConfig,
+} from '../mistral-ai/chatComplete';
 
 const VertexConfig: ProviderConfigs = {
   api: VertexApiConfig,
@@ -162,6 +167,17 @@ const VertexConfig: ProviderConfigs = {
         ...responseTransforms,
       },
     };
+    case 'mistralai':
+      return {
+        chatComplete: MistralAIChatCompleteConfig,
+        api: GoogleApiConfig,
+        responseTransforms: {
+          chatComplete:
+            GetMistralAIChatCompleteResponseTransform(GOOGLE_VERTEX_AI),
+          'stream-chatComplete':
+            GetMistralAIChatCompleteStreamChunkTransform(GOOGLE_VERTEX_AI),
+        },
+      };
     default:
       return baseConfig;
   }
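
For reference, a hedged sketch of a gateway request body that would select this new branch; the model id is hypothetical, but the `mistralai.` prefix is what routes the request here (see the utils.ts change below for how the prefix is parsed):

// Hypothetical request body; only the 'mistralai.' prefix is load-bearing.
const requestBody = {
  model: 'mistralai.mistral-large', // provider prefix selects the Vertex sub-config
  messages: [{ role: 'user', content: 'Hello from Vertex' }],
};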

src/providers/google-vertex-ai/utils.ts (3 additions, 1 deletion)

@@ -159,7 +159,9 @@ export const getModelAndProvider = (modelString: string) => {
   const modelStringParts = modelString.split('.');
   if (
     modelStringParts.length > 1 &&
-    ['google', 'anthropic', 'meta', 'endpoints'].includes(modelStringParts[0])
+    ['google', 'anthropic', 'meta', 'endpoints', 'mistralai'].includes(
+      modelStringParts[0]
+    )
   ) {
     provider = modelStringParts[0];
     model = modelStringParts.slice(1).join('.');
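
A standalone sketch of the check above (not the library's helper itself), showing how a prefixed model string now splits; the model id is hypothetical:

// Mirrors the updated condition: 'mistralai' is now a recognized prefix.
const parts = 'mistralai.mistral-large-2411'.split('.'); // hypothetical model string
const known = ['google', 'anthropic', 'meta', 'endpoints', 'mistralai'];
if (parts.length > 1 && known.includes(parts[0])) {
  const provider = parts[0];              // 'mistralai'
  const model = parts.slice(1).join('.'); // 'mistral-large-2411'
  console.log(provider, model);
}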

src/providers/mistral-ai/chatComplete.ts (89 additions, 94 deletions)

@@ -1,4 +1,3 @@
-import { MISTRAL_AI } from '../../globals';
 import { Params } from '../../types/requestBody';
 import {
   ChatCompletionResponse,
@@ -17,6 +16,9 @@ export const MistralAIChatCompleteConfig: ProviderConfig = {
     param: 'model',
     required: true,
     default: 'mistral-tiny',
+    transform: (params: Params) => {
+      return params.model?.replace('mistralai.', '');
+    },
   },
   messages: {
     param: 'messages',
@@ -152,104 +154,97 @@ interface MistralAIStreamChunk {
   };
 }
 
-export const MistralAIChatCompleteResponseTransform: (
-  response: MistralAIChatCompleteResponse | MistralAIErrorResponse,
-  responseStatus: number,
-  responseHeaders: Headers,
-  strictOpenAiCompliance: boolean,
-  gatewayRequestUrl: string,
-  gatewayRequest: Params
-) => ChatCompletionResponse | ErrorResponse = (
-  response,
-  responseStatus,
-  responseHeaders,
-  strictOpenAiCompliance,
-  gatewayRequestUrl,
-  gatewayRequest
-) => {
-  if ('message' in response && responseStatus !== 200) {
-    return generateErrorResponse(
-      {
-        message: response.message,
-        type: response.type,
-        param: response.param,
-        code: response.code,
-      },
-      MISTRAL_AI
-    );
-  }
+export const GetMistralAIChatCompleteResponseTransform = (provider: string) => {
+  return (
+    response: MistralAIChatCompleteResponse | MistralAIErrorResponse,
+    responseStatus: number,
+    _responseHeaders: Headers,
+    strictOpenAiCompliance: boolean,
+    _gatewayRequestUrl: string,
+    _gatewayRequest: Params
+  ): ChatCompletionResponse | ErrorResponse => {
+    if ('message' in response && responseStatus !== 200) {
+      return generateErrorResponse(
+        {
+          message: response.message,
+          type: response.type,
+          param: response.param,
+          code: response.code,
+        },
+        provider
+      );
+    }
 
-  if ('choices' in response) {
-    return {
-      id: response.id,
-      object: response.object,
-      created: response.created,
-      model: response.model,
-      provider: MISTRAL_AI,
-      choices: response.choices.map((c) => ({
-        index: c.index,
-        message: {
-          role: c.message.role,
-          content: c.message.content,
-          tool_calls: c.message.tool_calls,
+    if ('choices' in response) {
+      return {
+        id: response.id,
+        object: response.object,
+        created: response.created,
+        model: response.model,
+        provider: provider,
+        choices: response.choices.map((c) => ({
+          index: c.index,
+          message: {
+            role: c.message.role,
+            content: c.message.content,
+            tool_calls: c.message.tool_calls,
+          },
+          finish_reason: transformFinishReason(
+            c.finish_reason as MISTRAL_AI_FINISH_REASON,
+            strictOpenAiCompliance
+          ),
+        })),
+        usage: {
+          prompt_tokens: response.usage?.prompt_tokens,
+          completion_tokens: response.usage?.completion_tokens,
+          total_tokens: response.usage?.total_tokens,
         },
-        finish_reason: transformFinishReason(
-          c.finish_reason as MISTRAL_AI_FINISH_REASON,
-          strictOpenAiCompliance
-        ),
-      })),
-      usage: {
-        prompt_tokens: response.usage?.prompt_tokens,
-        completion_tokens: response.usage?.completion_tokens,
-        total_tokens: response.usage?.total_tokens,
-      },
-    };
-  }
+      };
+    }
 
-  return generateInvalidProviderResponseError(response, MISTRAL_AI);
+    return generateInvalidProviderResponseError(response, provider);
+  };
 };
 
-export const MistralAIChatCompleteStreamChunkTransform: (
-  response: string,
-  fallbackId: string,
-  streamState: any,
-  strictOpenAiCompliance: boolean,
-  gatewayRequest: Params
-) => string | string[] = (
-  responseChunk,
-  fallbackId,
-  _streamState,
-  strictOpenAiCompliance,
-  _gatewayRequest
+export const GetMistralAIChatCompleteStreamChunkTransform = (
+  provider: string
 ) => {
-  let chunk = responseChunk.trim();
-  chunk = chunk.replace(/^data: /, '');
-  chunk = chunk.trim();
-  if (chunk === '[DONE]') {
-    return `data: ${chunk}\n\n`;
-  }
-  const parsedChunk: MistralAIStreamChunk = JSON.parse(chunk);
-  const finishReason = parsedChunk.choices[0].finish_reason
-    ? transformFinishReason(
-        parsedChunk.choices[0].finish_reason as MISTRAL_AI_FINISH_REASON,
-        strictOpenAiCompliance
-      )
-    : null;
   return (
-    `data: ${JSON.stringify({
-      id: parsedChunk.id,
-      object: parsedChunk.object,
-      created: parsedChunk.created,
-      model: parsedChunk.model,
-      provider: MISTRAL_AI,
-      choices: [
-        {
-          index: parsedChunk.choices[0].index,
-          delta: parsedChunk.choices[0].delta,
-          finish_reason: finishReason,
-        },
-      ],
-      ...(parsedChunk.usage ? { usage: parsedChunk.usage } : {}),
-    })}` + '\n\n'
-  );
+    responseChunk: string,
+    fallbackId: string,
+    _streamState: any,
+    strictOpenAiCompliance: boolean,
+    _gatewayRequest: Params
+  ) => {
+    let chunk = responseChunk.trim();
+    chunk = chunk.replace(/^data: /, '');
+    chunk = chunk.trim();
+    if (chunk === '[DONE]') {
+      return `data: ${chunk}\n\n`;
+    }
+    const parsedChunk: MistralAIStreamChunk = JSON.parse(chunk);
+    const finishReason = parsedChunk.choices[0].finish_reason
+      ? transformFinishReason(
+          parsedChunk.choices[0].finish_reason as MISTRAL_AI_FINISH_REASON,
+          strictOpenAiCompliance
+        )
+      : null;
+    return (
+      `data: ${JSON.stringify({
+        id: parsedChunk.id,
+        object: parsedChunk.object,
+        created: parsedChunk.created,
+        model: parsedChunk.model,
+        provider: provider,
+        choices: [
+          {
+            index: parsedChunk.choices[0].index,
+            delta: parsedChunk.choices[0].delta,
+            finish_reason: finishReason,
+          },
+        ],
+        ...(parsedChunk.usage ? { usage: parsedChunk.usage } : {}),
+      })}` + '\n\n'
+    );
+  };
 };
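
This refactor replaces two transforms hard-wired to `MISTRAL_AI` with factories that close over a `provider` string, so the same normalization logic can report either Mistral or Vertex as the source. A minimal sketch of the pattern, assuming the provider constants are plain strings (the literals below are stand-ins for `MISTRAL_AI` and `GOOGLE_VERTEX_AI`):

// Sketch of the factory pattern used above, reduced to its essentials.
const makeTransform = (provider: string) => {
  return (response: { id: string }) => ({
    ...response,
    provider, // stamped into every normalized response, as in the diff
  });
};

const direct = makeTransform('mistral-ai');
const onVertex = makeTransform('google-vertex-ai');
console.log(direct({ id: 'x' }).provider);   // mistral-ai
console.log(onVertex({ id: 'x' }).provider); // google-vertex-ai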

src/providers/mistral-ai/index.ts (6 additions, 4 deletions)

@@ -1,9 +1,10 @@
+import { MISTRAL_AI } from '../../globals';
 import { ProviderConfigs } from '../types';
 import MistralAIAPIConfig from './api';
 import {
+  GetMistralAIChatCompleteResponseTransform,
+  GetMistralAIChatCompleteStreamChunkTransform,
   MistralAIChatCompleteConfig,
-  MistralAIChatCompleteResponseTransform,
-  MistralAIChatCompleteStreamChunkTransform,
 } from './chatComplete';
 import { MistralAIEmbedConfig, MistralAIEmbedResponseTransform } from './embed';
 
@@ -12,8 +13,9 @@ const MistralAIConfig: ProviderConfigs = {
   embed: MistralAIEmbedConfig,
   api: MistralAIAPIConfig,
   responseTransforms: {
-    chatComplete: MistralAIChatCompleteResponseTransform,
-    'stream-chatComplete': MistralAIChatCompleteStreamChunkTransform,
+    chatComplete: GetMistralAIChatCompleteResponseTransform(MISTRAL_AI),
+    'stream-chatComplete':
+      GetMistralAIChatCompleteStreamChunkTransform(MISTRAL_AI),
     embed: MistralAIEmbedResponseTransform,
   },
 };
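
Behavior for the standalone Mistral provider should be unchanged: calling the factory with `MISTRAL_AI` yields a function with the same shape as the removed `MistralAIChatCompleteResponseTransform`. A hedged usage sketch of the stream transform with a fabricated SSE line (all payload values made up; only the call shape matches the code above):

import { GetMistralAIChatCompleteStreamChunkTransform } from './chatComplete';

// 'mistral-ai' stands in for the MISTRAL_AI constant.
const streamTransform = GetMistralAIChatCompleteStreamChunkTransform('mistral-ai');
const sseLine =
  'data: {"id":"c1","object":"chat.completion.chunk","created":1,' +
  '"model":"mistral-tiny","choices":[{"index":0,"delta":{"content":"Hi"},' +
  '"finish_reason":null}]}';
// Args: (chunk, fallbackId, streamState, strictOpenAiCompliance, gatewayRequest)
const out = streamTransform(sseLine, 'fallback-id', {}, false, {} as any);
// out is an SSE 'data: {...}\n\n' string whose payload now carries
// provider: 'mistral-ai'.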
