Skip to content

Commit cccd150

Browse files
authored
Merge pull request #1148 from b4s36t4/feat/azure-ai-openai-endpoints
feat: support extra endpoints for azure-ai provider
2 parents 7073b5d + 2ee4b7f commit cccd150

File tree

4 files changed

+285
-7
lines changed

4 files changed

+285
-7
lines changed

src/providers/azure-ai-inference/api.ts

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,28 @@ import {
55
} from '../azure-openai/utils';
66
import { ProviderAPIConfig } from '../types';
77

8+
// Endpoints that are NOT served from the deployment-scoped inference path.
// For these, getBaseURL strips the `/deployments/<deployment>` segment from
// the Azure Foundry URL and targets `<origin>/openai` instead (batch and
// file management live at the account level, not per-deployment).
const NON_INFERENCE_ENDPOINTS = [
  'createBatch',
  'retrieveBatch',
  'cancelBatch',
  'getBatchOutput',
  'listBatches',
  'uploadFile',
  'listFiles',
  'retrieveFile',
  'deleteFile',
  'retrieveFileContent',
];
20+
821
const AzureAIInferenceAPI: ProviderAPIConfig = {
9-
getBaseURL: ({ providerOptions }) => {
22+
getBaseURL: ({ providerOptions, fn }) => {
1023
const { provider, azureFoundryUrl } = providerOptions;
24+
25+
// Azure Foundry URL includes `/deployments/<deployment>`, strip out and append openai for batches/finetunes
26+
if (fn && NON_INFERENCE_ENDPOINTS.includes(fn)) {
27+
return new URL(azureFoundryUrl ?? '').origin + '/openai';
28+
}
29+
1130
if (provider === GITHUB) {
1231
return 'https://models.inference.ai.azure.com';
1332
}
@@ -17,7 +36,7 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
1736

1837
return '';
1938
},
20-
headers: async ({ providerOptions }) => {
39+
headers: async ({ providerOptions, fn }) => {
2140
const {
2241
apiKey,
2342
azureExtraParams,
@@ -31,6 +50,13 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
3150
...(azureDeploymentName && {
3251
'azureml-model-deployment': azureDeploymentName,
3352
}),
53+
...(['createTranscription', 'createTranslation', 'uploadFile'].includes(
54+
fn
55+
)
56+
? {
57+
'Content-Type': 'multipart/form-data',
58+
}
59+
: {}),
3460
};
3561
if (azureAdToken) {
3662
headers['Authorization'] =
@@ -70,14 +96,37 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
7096
}
7197
return headers;
7298
},
73-
getEndpoint: ({ providerOptions, fn }) => {
99+
getEndpoint: ({ providerOptions, fn, gatewayRequestURL }) => {
74100
const { azureApiVersion, urlToFetch } = providerOptions;
75101
let mappedFn = fn;
76102

103+
const urlObj = new URL(gatewayRequestURL);
104+
const path = urlObj.pathname.replace('/v1', '');
105+
const searchParams = urlObj.searchParams;
106+
107+
if (azureApiVersion) {
108+
searchParams.set('api-version', azureApiVersion);
109+
}
110+
77111
const ENDPOINT_MAPPING: Record<string, string> = {
78112
complete: '/completions',
79113
chatComplete: '/chat/completions',
80114
embed: '/embeddings',
115+
realtime: '/realtime',
116+
imageGenerate: '/images/generations',
117+
createSpeech: '/audio/speech',
118+
createTranscription: '/audio/transcriptions',
119+
createTranslation: '/audio/translations',
120+
uploadFile: path,
121+
retrieveFile: path,
122+
listFiles: path,
123+
deleteFile: path,
124+
retrieveFileContent: path,
125+
listBatches: path,
126+
retrieveBatch: path,
127+
cancelBatch: path,
128+
getBatchOutput: path,
129+
createBatch: path,
81130
};
82131

83132
const isGithub = providerOptions.provider === GITHUB;
@@ -92,23 +141,40 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
92141
}
93142
}
94143

95-
const apiVersion = azureApiVersion ? `?api-version=${azureApiVersion}` : '';
144+
const searchParamsString = searchParams.toString();
96145
switch (mappedFn) {
97146
case 'complete': {
98147
return isGithub
99148
? ENDPOINT_MAPPING[mappedFn]
100-
: `${ENDPOINT_MAPPING[mappedFn]}${apiVersion}`;
149+
: `${ENDPOINT_MAPPING[mappedFn]}?${searchParamsString}`;
101150
}
102151
case 'chatComplete': {
103152
return isGithub
104153
? ENDPOINT_MAPPING[mappedFn]
105-
: `${ENDPOINT_MAPPING[mappedFn]}${apiVersion}`;
154+
: `${ENDPOINT_MAPPING[mappedFn]}?${searchParamsString}`;
106155
}
107156
case 'embed': {
108157
return isGithub
109158
? ENDPOINT_MAPPING[mappedFn]
110-
: `${ENDPOINT_MAPPING[mappedFn]}${apiVersion}`;
159+
: `${ENDPOINT_MAPPING[mappedFn]}?${searchParamsString}`;
111160
}
161+
case 'realtime':
162+
case 'imageGenerate':
163+
case 'createSpeech':
164+
case 'createTranscription':
165+
case 'createTranslation':
166+
case 'cancelBatch':
167+
case 'createBatch':
168+
case 'getBatchOutput':
169+
case 'retrieveBatch':
170+
case 'listBatches':
171+
case 'retrieveFile':
172+
case 'listFiles':
173+
case 'deleteFile':
174+
case 'retrieveFileContent': {
175+
return `${ENDPOINT_MAPPING[mappedFn]}?${searchParamsString}`;
176+
}
177+
112178
default:
113179
return '';
114180
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import { Context } from 'hono';
2+
import AzureAIInferenceAPI from './api';
3+
import { Options } from '../../types/requestBody';
4+
import { RetrieveBatchResponse } from '../types';
5+
import { AZURE_OPEN_AI } from '../../globals';
6+
7+
// Return a ReadableStream containing batches output data
8+
export const AzureAIInferenceGetBatchOutputRequestHandler = async ({
9+
c,
10+
providerOptions,
11+
requestURL,
12+
}: {
13+
c: Context;
14+
providerOptions: Options;
15+
requestURL: string;
16+
}) => {
17+
// get batch details which has ouptut file id
18+
// get file content as ReadableStream
19+
// return file content
20+
const baseUrl = AzureAIInferenceAPI.getBaseURL({
21+
providerOptions,
22+
fn: 'retrieveBatch',
23+
c,
24+
gatewayRequestURL: requestURL,
25+
});
26+
const retrieveBatchRequestURL = requestURL.replace('/output', '');
27+
const retrieveBatchURL =
28+
baseUrl +
29+
AzureAIInferenceAPI.getEndpoint({
30+
providerOptions,
31+
fn: 'retrieveBatch',
32+
gatewayRequestURL: retrieveBatchRequestURL,
33+
c,
34+
gatewayRequestBodyJSON: {},
35+
gatewayRequestBody: {},
36+
});
37+
const retrieveBatchesHeaders = await AzureAIInferenceAPI.headers({
38+
c,
39+
providerOptions,
40+
fn: 'retrieveBatch',
41+
transformedRequestBody: {},
42+
transformedRequestUrl: retrieveBatchURL,
43+
gatewayRequestBody: {},
44+
});
45+
try {
46+
const retrieveBatchesResponse = await fetch(retrieveBatchURL, {
47+
method: 'GET',
48+
headers: retrieveBatchesHeaders,
49+
});
50+
51+
if (!retrieveBatchesResponse.ok) {
52+
const error = await retrieveBatchesResponse.text();
53+
return new Response(
54+
JSON.stringify({
55+
error: error || 'error fetching batch output',
56+
provider: AZURE_OPEN_AI,
57+
param: null,
58+
}),
59+
{
60+
status: 500,
61+
}
62+
);
63+
}
64+
65+
const batchDetails: RetrieveBatchResponse =
66+
await retrieveBatchesResponse.json();
67+
68+
const outputFileId =
69+
batchDetails.output_file_id || batchDetails.error_file_id;
70+
if (!outputFileId) {
71+
const errors = batchDetails.errors;
72+
if (errors) {
73+
return new Response(JSON.stringify(errors), {
74+
status: 200,
75+
});
76+
}
77+
return new Response(
78+
JSON.stringify({
79+
error: 'invalid response output format',
80+
provider_response: batchDetails,
81+
provider: AZURE_OPEN_AI,
82+
}),
83+
{
84+
status: 400,
85+
}
86+
);
87+
}
88+
const retrieveFileContentRequestURL = `https://api.portkey.ai/v1/files/${outputFileId}/content`; // construct the entire url instead of the path of sanity sake
89+
const retrieveFileContentURL =
90+
baseUrl +
91+
AzureAIInferenceAPI.getEndpoint({
92+
providerOptions,
93+
fn: 'retrieveFileContent',
94+
gatewayRequestURL: retrieveFileContentRequestURL,
95+
c,
96+
gatewayRequestBodyJSON: {},
97+
gatewayRequestBody: {},
98+
});
99+
const retrieveFileContentHeaders = await AzureAIInferenceAPI.headers({
100+
c,
101+
providerOptions,
102+
fn: 'retrieveFileContent',
103+
transformedRequestBody: {},
104+
transformedRequestUrl: retrieveFileContentURL,
105+
gatewayRequestBody: {},
106+
});
107+
const response = fetch(retrieveFileContentURL, {
108+
method: 'GET',
109+
headers: retrieveFileContentHeaders,
110+
});
111+
return response;
112+
} catch (e) {
113+
return new Response(
114+
JSON.stringify({
115+
error: 'error fetching batch output',
116+
provider: AZURE_OPEN_AI,
117+
param: null,
118+
}),
119+
{
120+
status: 500,
121+
}
122+
);
123+
}
124+
};

src/providers/azure-ai-inference/index.ts

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,58 @@ import {
1313
AzureAIInferenceChatCompleteResponseTransform,
1414
} from './chatComplete';
1515
import { AZURE_AI_INFERENCE, GITHUB } from '../../globals';
16+
import { AzureOpenAIImageGenerateConfig } from '../azure-openai/imageGenerate';
17+
import { AzureOpenAICreateSpeechConfig } from '../azure-openai/createSpeech';
18+
import { OpenAICreateFinetuneConfig } from '../openai/createFinetune';
19+
import { AzureOpenAICreateBatchConfig } from '../azure-openai/createBatch';
20+
import { AzureAIInferenceGetBatchOutputRequestHandler } from './getBatchOutput';
21+
import { OpenAIFileUploadRequestTransform } from '../openai/uploadFile';
22+
import {
23+
AzureAIInferenceCreateSpeechResponseTransform,
24+
AzureAIInferenceCreateTranscriptionResponseTransform,
25+
AzureAIInferenceCreateTranslationResponseTransform,
26+
AzureAIInferenceResponseTransform,
27+
} from './utils';
1628

1729
// Provider configuration for azure-ai-inference: maps gateway functions to
// request-param configs, request/response transforms, and custom handlers.
const AzureAIInferenceAPIConfig: ProviderConfigs = {
  complete: AzureAIInferenceCompleteConfig,
  embed: AzureAIInferenceEmbedConfig,
  api: AzureAIInferenceAPI,
  chatComplete: AzureAIInferenceChatCompleteConfig,
  // Image, speech, finetune and batch creation reuse the Azure OpenAI /
  // OpenAI param configs (these endpoints are OpenAI-compatible here).
  imageGenerate: AzureOpenAIImageGenerateConfig,
  createSpeech: AzureOpenAICreateSpeechConfig,
  createFinetune: OpenAICreateFinetuneConfig,
  // Empty configs: no param transformation is applied for these functions.
  createTranscription: {},
  createTranslation: {},
  realtime: {},
  cancelBatch: {},
  createBatch: AzureOpenAICreateBatchConfig,
  cancelFinetune: {},
  requestHandlers: {
    // getBatchOutput needs a multi-step custom handler: retrieve the batch,
    // then stream its output file's content.
    getBatchOutput: AzureAIInferenceGetBatchOutputRequestHandler,
  },
  requestTransforms: {
    uploadFile: OpenAIFileUploadRequestTransform,
  },
  responseTransforms: {
    complete: AzureAIInferenceCompleteResponseTransform(AZURE_AI_INFERENCE),
    chatComplete:
      AzureAIInferenceChatCompleteResponseTransform(AZURE_AI_INFERENCE),
    embed: AzureAIInferenceEmbedResponseTransform(AZURE_AI_INFERENCE),
    imageGenerate: AzureAIInferenceResponseTransform,
    createSpeech: AzureAIInferenceCreateSpeechResponseTransform,
    createTranscription: AzureAIInferenceCreateTranscriptionResponseTransform,
    createTranslation: AzureAIInferenceCreateTranslationResponseTransform,
    realtime: {},
    // Batch and file endpoints all share the generic response transform.
    createBatch: AzureAIInferenceResponseTransform,
    retrieveBatch: AzureAIInferenceResponseTransform,
    cancelBatch: AzureAIInferenceResponseTransform,
    listBatches: AzureAIInferenceResponseTransform,
    uploadFile: AzureAIInferenceResponseTransform,
    listFiles: AzureAIInferenceResponseTransform,
    retrieveFile: AzureAIInferenceResponseTransform,
    deleteFile: AzureAIInferenceResponseTransform,
    retrieveFileContent: AzureAIInferenceResponseTransform,
  },
};
2970

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import { AZURE_AI_INFERENCE } from '../../globals';
2+
import { OpenAIErrorResponseTransform } from '../openai/utils';
3+
import { ErrorResponse } from '../types';
4+
5+
export const AzureAIInferenceResponseTransform = (
6+
response: any,
7+
responseStatus: number
8+
) => {
9+
if (responseStatus !== 200 && 'error' in response) {
10+
return OpenAIErrorResponseTransform(response, AZURE_AI_INFERENCE);
11+
}
12+
13+
return { ...response, provider: AZURE_AI_INFERENCE };
14+
};
15+
16+
export const AzureAIInferenceCreateSpeechResponseTransform = (
17+
response: any,
18+
responseStatus: number
19+
) => {
20+
if (responseStatus !== 200 && 'error' in response) {
21+
return OpenAIErrorResponseTransform(response, AZURE_AI_INFERENCE);
22+
}
23+
24+
return { ...response, provider: AZURE_AI_INFERENCE };
25+
};
26+
27+
export const AzureAIInferenceCreateTranscriptionResponseTransform = (
28+
response: any,
29+
responseStatus: number
30+
) => {
31+
if (responseStatus !== 200 && 'error' in response) {
32+
return OpenAIErrorResponseTransform(response, AZURE_AI_INFERENCE);
33+
}
34+
35+
return { ...response, provider: AZURE_AI_INFERENCE };
36+
};
37+
38+
export const AzureAIInferenceCreateTranslationResponseTransform = (
39+
response: any,
40+
responseStatus: number
41+
) => {
42+
if (responseStatus !== 200 && 'error' in response) {
43+
return OpenAIErrorResponseTransform(response, AZURE_AI_INFERENCE);
44+
}
45+
46+
return { ...response, provider: AZURE_AI_INFERENCE };
47+
};

0 commit comments

Comments
 (0)