@@ -1,129 +1,56 @@
+import { BYTEZ } from '../../globals';
 import { ProviderConfigs } from '../types';
+import { generateErrorResponse } from '../utils';
 import BytezInferenceAPI from './api';
 import { BytezInferenceChatCompleteConfig } from './chatComplete';
-import { bodyAdapter, LRUCache } from './utils';
 import { BytezResponse } from './types';
 
-const BASE_URL = 'https://api.bytez.com/models/v2';
-
-const IS_CHAT_MODEL_CACHE = new LRUCache({ size: 100 });
-
 const BytezInferenceAPIConfig: ProviderConfigs = {
   api: BytezInferenceAPI,
   chatComplete: BytezInferenceChatCompleteConfig,
-  requestHandlers: {
-    chatComplete: async ({ providerOptions, requestBody }) => {
-      try {
-        const { model: modelId } = requestBody;
-
-        const adaptedBody = bodyAdapter(requestBody);
-
-        const headers = {
-          'Content-Type': 'application/json',
-          Authorization: `Key ${providerOptions.apiKey}`,
-        };
-
-        const isChatModel = await validateModelIsChat(modelId, headers);
-
-        if (!isChatModel) {
-          return constructFailureResponse(
-            'Bytez only supports chat models on PortKey',
-            { status: 400 }
-          );
-        }
-
-        const url = `${BASE_URL}/${modelId}`;
-
-        const response = await fetch(url, {
-          method: 'POST',
-          headers,
-          body: JSON.stringify(adaptedBody),
-        });
-
-        if (adaptedBody.stream) {
-          return new Response(response.body, response);
-        }
-
-        const { error, output }: BytezResponse = await response.json();
-
-        if (error) {
-          return constructFailureResponse(error, response);
-        }
-
-        return new Response(
-          JSON.stringify({
-            id: crypto.randomUUID(),
-            object: 'chat.completion',
-            created: Date.now(),
-            model: modelId,
-            choices: [
-              {
-                index: 0,
-                message: output,
-                logprobs: null,
-                finish_reason: 'stop',
-              },
-            ],
-            usage: {
-              inferenceTime: response.headers.get('inference-time'),
-              modelSize: response.headers.get('inference-meter'),
-            },
-          }),
-          response
+  responseTransforms: {
+    chatComplete: (
+      response: BytezResponse,
+      responseStatus: number,
+      responseHeaders: any,
+      strictOpenAiCompliance: boolean,
+      endpoint: string,
+      requestBody: any
+    ) => {
+      const { error, output } = response;
+
+      if (error) {
+        return generateErrorResponse(
+          {
+            message: error,
+            type: String(responseStatus),
+            param: null,
+            code: null,
+          },
+          BYTEZ
         );
-      } catch (error: any) {
-        return constructFailureResponse(error.message);
       }
+
+      return {
+        id: crypto.randomUUID(),
+        object: 'chat.completion',
+        created: Date.now(),
+        model: requestBody.model,
+        choices: [
+          {
+            index: 0,
+            message: output,
+            logprobs: null,
+            finish_reason: 'stop',
+          },
+        ],
+        usage: {
+          inferenceTime: responseHeaders.get('inference-time'),
+          modelSize: responseHeaders.get('inference-meter'),
+        },
+      };
     },
   },
 };
 
-async function validateModelIsChat(
-  modelId: string,
-  headers: Record<string, any>
-) {
-  // return from cache if already validated
-  if (IS_CHAT_MODEL_CACHE.has(modelId)) {
-    return IS_CHAT_MODEL_CACHE.get(modelId);
-  }
-
-  const url = `${BASE_URL}/list/models?modelId=${modelId}`;
-
-  const response = await fetch(url, {
-    headers,
-  });
-
-  const {
-    error,
-    output: [model],
-  }: BytezResponse = await response.json();
-
-  if (error) {
-    throw new Error(error);
-  }
-
-  const isChatModel = model.task === 'chat';
-
-  IS_CHAT_MODEL_CACHE.set(modelId, isChatModel);
-
-  return isChatModel;
-}
-
-function constructFailureResponse(message: string, response?: object) {
-  return new Response(
-    JSON.stringify({
-      status: 'failure',
-      message,
-    }),
-    {
-      status: 500,
-      headers: {
-        'content-type': 'application/json',
-      },
-      // override defaults if desired
-      ...response,
-    }
-  );
-}
-
 export default BytezInferenceAPIConfig;
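For orientation, a minimal usage sketch of the responseTransforms.chatComplete hook added above. Only the transform's signature and the { error, output } shape of BytezResponse are taken from this diff; the import path, the sample model id, and the header values are assumptions for illustration, and the gateway's real invocation path is not part of this change.

import BytezInferenceAPIConfig from './index'; // assumed path to this provider config

// A successful Bytez response: `output` already carries an OpenAI-style message object.
const bytezResponse = {
  error: null,
  output: { role: 'assistant', content: 'Hello from Bytez!' },
};

// Headers mirroring the fields the transform reads (values are made up).
const responseHeaders = new Headers({
  'inference-time': '0.42',
  'inference-meter': '7b',
});

const openAiStyle = BytezInferenceAPIConfig.responseTransforms?.chatComplete(
  bytezResponse, // parsed Bytez response body
  200, // responseStatus
  responseHeaders, // responseHeaders
  false, // strictOpenAiCompliance
  'chatComplete', // endpoint
  { model: 'openai-community/gpt2' } // requestBody (model id is illustrative)
);

// On success, the transform wraps `output` in an OpenAI-compatible envelope:
// openAiStyle.choices[0].message === bytezResponse.output
// openAiStyle.usage.inferenceTime === '0.42'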