
Commit 7988e9e

Merge branch 'main' into fix-readme
2 parents d66a66b + 29c13cc

File tree

8 files changed: +308 −12 lines changed

plugins/default/webhook.ts

Lines changed: 13 additions & 3 deletions

```diff
@@ -4,7 +4,7 @@ import {
   PluginHandler,
   PluginParameters,
 } from '../types';
-import { post } from '../utils';
+import { post, TimeoutError } from '../utils';
 
 function parseHeaders(headers: unknown): Record<string, string> {
   try {
@@ -107,20 +107,30 @@ export const handler: PluginHandler = async (
       responseData: response.data,
       requestContext: {
         headers,
-        timeout: 3000,
+        timeout: parameters.timeout || 3000,
       },
     };
   } catch (e: any) {
     error = e;
     delete error.stack;
 
+    const isTimeoutError = e instanceof TimeoutError;
+
+    const responseData = !isTimeoutError && e.response?.body;
+    const responseDataContentType = e.response?.headers?.get('content-type');
+
     data = {
       explanation: `Webhook error: ${e.message}`,
       webhookUrl: parameters.webhookURL || 'No URL provided',
       requestContext: {
         headers: parameters.headers || {},
-        timeout: 3000,
+        timeout: parameters.timeout || 3000,
      },
+      // return the response body if it's a non-ok response and not a timeout error
+      ...(responseData &&
+        responseDataContentType === 'application/json' && {
+          responseData: JSON.parse(responseData),
+        }),
    };
  }
```
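
For context, a hedged sketch of a guardrail check configuration that exercises the new `timeout` parameter. The check id, URL, and token are illustrative placeholders; only `webhookURL`, `headers`, and `timeout` are parameters visible in this diff.

```ts
// Hypothetical webhook guardrail config; the endpoint and token are placeholders.
const webhookCheck = {
  id: 'default.webhook', // assumed check id, not confirmed by this commit
  parameters: {
    webhookURL: 'https://example.com/guardrail-hook',
    headers: { Authorization: 'Bearer <token>' },
    timeout: 5000, // overrides the previously hard-coded 3000 ms default
  },
};
```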

plugins/utils.ts

Lines changed: 3 additions & 1 deletion

```diff
@@ -9,6 +9,7 @@ export interface ErrorResponse {
   status: number;
   statusText: string;
   body: string;
+  headers?: Headers;
 }
 
 export class HttpError extends Error {
@@ -21,7 +22,7 @@ export class HttpError extends Error {
   }
 }
 
-class TimeoutError extends Error {
+export class TimeoutError extends Error {
   url: string;
   timeout: number;
   method: string;
@@ -222,6 +223,7 @@ export async function post<T = any>(
       status: response.status,
       statusText: response.statusText,
       body: errorBody,
+      headers: response.headers,
     };
 
     throw new HttpError(
```
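
A minimal sketch of the consumer pattern the new `headers` field enables, assuming an error shape like `ErrorResponse` above; the helper name and local interface are hypothetical stand-ins, not exports of this module.

```ts
// Hypothetical helper: with `headers` now on the error response, callers can
// branch on content-type before parsing the body (as webhook.ts does above).
interface ErrorResponseLike {
  status: number;
  statusText: string;
  body: string;
  headers?: Headers;
}

function parseErrorBody(response: ErrorResponseLike): unknown {
  const contentType = response.headers?.get('content-type');
  return contentType === 'application/json'
    ? JSON.parse(response.body)
    : response.body;
}
```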

src/middlewares/hooks/index.ts

Lines changed: 6 additions & 1 deletion

```diff
@@ -306,6 +306,8 @@ export class HooksManager {
         transformed: result.transformed || false,
         created_at: createdAt,
         log: result.log || null,
+        fail_on_error:
+          (check.parameters as Record<string, any>)?.failOnError || false,
       };
     } catch (err: any) {
       console.error(`Error executing check "${check.id}":`, err);
@@ -390,7 +392,10 @@ export class HooksManager {
     }
 
     hookResult = {
-      verdict: checkResults.every((result) => result.verdict || result.error),
+      // treat an errored check as passing unless its failOnError flag is set
+      verdict: checkResults.every(
+        (result) => result.verdict || (result.error && !result.fail_on_error)
+      ),
       id: hook.id,
       transformed: checkResults.some((result) => result.transformed),
       checks: checkResults,
```
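
The new verdict rule in isolation, as a self-contained sketch; the type below is a local stand-in reduced to the fields the verdict actually reads, not the full `GuardrailCheckResult`.

```ts
// Local stand-in for GuardrailCheckResult, reduced to the relevant fields.
type CheckResult = { verdict: boolean; error?: Error; fail_on_error?: boolean };

const verdictFor = (results: CheckResult[]) =>
  results.every((r) => r.verdict || (r.error && !r.fail_on_error));

// An errored check passes by default, preserving the old behavior...
console.log(verdictFor([{ verdict: false, error: new Error('boom') }])); // true
// ...but fails the hook once failOnError is opted into.
console.log(
  verdictFor([
    { verdict: false, error: new Error('boom'), fail_on_error: true },
  ])
); // false
```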

src/middlewares/hooks/types.ts

Lines changed: 1 addition & 0 deletions

```diff
@@ -80,6 +80,7 @@ export interface GuardrailCheckResult {
     };
   };
   log?: any;
+  fail_on_error?: boolean;
 }
 
 export interface GuardrailResult {
```

src/providers/azure-ai-inference/api.ts

Lines changed: 73 additions & 7 deletions

```diff
@@ -5,9 +5,28 @@ import {
 } from '../azure-openai/utils';
 import { ProviderAPIConfig } from '../types';
 
+const NON_INFERENCE_ENDPOINTS = [
+  'createBatch',
+  'retrieveBatch',
+  'cancelBatch',
+  'getBatchOutput',
+  'listBatches',
+  'uploadFile',
+  'listFiles',
+  'retrieveFile',
+  'deleteFile',
+  'retrieveFileContent',
+];
+
 const AzureAIInferenceAPI: ProviderAPIConfig = {
-  getBaseURL: ({ providerOptions }) => {
+  getBaseURL: ({ providerOptions, fn }) => {
     const { provider, azureFoundryUrl } = providerOptions;
+
+    // the Azure Foundry URL includes `/deployments/<deployment>`; strip it and append `/openai` for batch/file endpoints
+    if (fn && NON_INFERENCE_ENDPOINTS.includes(fn)) {
+      return new URL(azureFoundryUrl ?? '').origin + '/openai';
+    }
+
     if (provider === GITHUB) {
       return 'https://models.inference.ai.azure.com';
     }
@@ -17,7 +36,7 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
 
     return '';
   },
-  headers: async ({ providerOptions }) => {
+  headers: async ({ providerOptions, fn }) => {
     const {
       apiKey,
       azureExtraParams,
@@ -31,6 +50,13 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
       ...(azureDeploymentName && {
         'azureml-model-deployment': azureDeploymentName,
       }),
+      ...(['createTranscription', 'createTranslation', 'uploadFile'].includes(
+        fn
+      )
+        ? {
+            'Content-Type': 'multipart/form-data',
+          }
+        : {}),
     };
     if (azureAdToken) {
       headers['Authorization'] =
@@ -70,14 +96,37 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
     }
     return headers;
   },
-  getEndpoint: ({ providerOptions, fn }) => {
+  getEndpoint: ({ providerOptions, fn, gatewayRequestURL }) => {
     const { azureApiVersion, urlToFetch } = providerOptions;
     let mappedFn = fn;
 
+    const urlObj = new URL(gatewayRequestURL);
+    const path = urlObj.pathname.replace('/v1', '');
+    const searchParams = urlObj.searchParams;
+
+    if (azureApiVersion) {
+      searchParams.set('api-version', azureApiVersion);
+    }
+
     const ENDPOINT_MAPPING: Record<string, string> = {
       complete: '/completions',
       chatComplete: '/chat/completions',
       embed: '/embeddings',
+      realtime: '/realtime',
+      imageGenerate: '/images/generations',
+      createSpeech: '/audio/speech',
+      createTranscription: '/audio/transcriptions',
+      createTranslation: '/audio/translations',
+      uploadFile: path,
+      retrieveFile: path,
+      listFiles: path,
+      deleteFile: path,
+      retrieveFileContent: path,
+      listBatches: path,
+      retrieveBatch: path,
+      cancelBatch: path,
+      getBatchOutput: path,
+      createBatch: path,
     };
 
     const isGithub = providerOptions.provider === GITHUB;
@@ -92,23 +141,40 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
       }
     }
 
-    const apiVersion = azureApiVersion ? `?api-version=${azureApiVersion}` : '';
+    const searchParamsString = searchParams.toString();
     switch (mappedFn) {
       case 'complete': {
         return isGithub
           ? ENDPOINT_MAPPING[mappedFn]
-          : `${ENDPOINT_MAPPING[mappedFn]}${apiVersion}`;
+          : `${ENDPOINT_MAPPING[mappedFn]}?${searchParamsString}`;
       }
       case 'chatComplete': {
         return isGithub
           ? ENDPOINT_MAPPING[mappedFn]
-          : `${ENDPOINT_MAPPING[mappedFn]}${apiVersion}`;
+          : `${ENDPOINT_MAPPING[mappedFn]}?${searchParamsString}`;
       }
       case 'embed': {
         return isGithub
           ? ENDPOINT_MAPPING[mappedFn]
-          : `${ENDPOINT_MAPPING[mappedFn]}${apiVersion}`;
+          : `${ENDPOINT_MAPPING[mappedFn]}?${searchParamsString}`;
       }
+      case 'realtime':
+      case 'imageGenerate':
+      case 'createSpeech':
+      case 'createTranscription':
+      case 'createTranslation':
+      case 'cancelBatch':
+      case 'createBatch':
+      case 'getBatchOutput':
+      case 'retrieveBatch':
+      case 'listBatches':
+      case 'retrieveFile':
+      case 'listFiles':
+      case 'deleteFile':
+      case 'retrieveFileContent': {
+        return `${ENDPOINT_MAPPING[mappedFn]}?${searchParamsString}`;
+      }
+
       default:
         return '';
     }
```
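
To trace what `getBaseURL` and `getEndpoint` now produce for a batch call, here is a standalone sketch of the same URL rewriting. The hostnames and `api-version` value are illustrative assumptions, not values from this commit.

```ts
// Illustrative inputs; real values come from providerOptions and the gateway request.
const gatewayRequestURL = 'https://gateway.example.com/v1/batches/batch_abc123';
const azureFoundryUrl =
  'https://myresource.services.ai.azure.com/deployments/gpt-4o'; // assumed shape
const azureApiVersion = '2024-07-01-preview'; // assumed version string

// getEndpoint: strip the gateway's `/v1` prefix and pin api-version as a query param.
const urlObj = new URL(gatewayRequestURL);
const path = urlObj.pathname.replace('/v1', ''); // '/batches/batch_abc123'
urlObj.searchParams.set('api-version', azureApiVersion);

// getBaseURL: non-inference endpoints resolve against the Foundry origin + '/openai'.
const baseURL = new URL(azureFoundryUrl).origin + '/openai';

console.log(`${baseURL}${path}?${urlObj.searchParams.toString()}`);
// => https://myresource.services.ai.azure.com/openai/batches/batch_abc123?api-version=2024-07-01-preview
```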

Lines changed: 124 additions & 0 deletions (new file)

```ts
import { Context } from 'hono';
import AzureAIInferenceAPI from './api';
import { Options } from '../../types/requestBody';
import { RetrieveBatchResponse } from '../types';
import { AZURE_OPEN_AI } from '../../globals';

// Return a ReadableStream containing the batch output data
export const AzureAIInferenceGetBatchOutputRequestHandler = async ({
  c,
  providerOptions,
  requestURL,
}: {
  c: Context;
  providerOptions: Options;
  requestURL: string;
}) => {
  // 1. get the batch details, which include the output file id
  // 2. get the file content as a ReadableStream
  // 3. return the file content
  const baseUrl = AzureAIInferenceAPI.getBaseURL({
    providerOptions,
    fn: 'retrieveBatch',
    c,
    gatewayRequestURL: requestURL,
  });
  const retrieveBatchRequestURL = requestURL.replace('/output', '');
  const retrieveBatchURL =
    baseUrl +
    AzureAIInferenceAPI.getEndpoint({
      providerOptions,
      fn: 'retrieveBatch',
      gatewayRequestURL: retrieveBatchRequestURL,
      c,
      gatewayRequestBodyJSON: {},
      gatewayRequestBody: {},
    });
  const retrieveBatchesHeaders = await AzureAIInferenceAPI.headers({
    c,
    providerOptions,
    fn: 'retrieveBatch',
    transformedRequestBody: {},
    transformedRequestUrl: retrieveBatchURL,
    gatewayRequestBody: {},
  });
  try {
    const retrieveBatchesResponse = await fetch(retrieveBatchURL, {
      method: 'GET',
      headers: retrieveBatchesHeaders,
    });

    if (!retrieveBatchesResponse.ok) {
      const error = await retrieveBatchesResponse.text();
      return new Response(
        JSON.stringify({
          error: error || 'error fetching batch output',
          provider: AZURE_OPEN_AI,
          param: null,
        }),
        {
          status: 500,
        }
      );
    }

    const batchDetails: RetrieveBatchResponse =
      await retrieveBatchesResponse.json();

    const outputFileId =
      batchDetails.output_file_id || batchDetails.error_file_id;
    if (!outputFileId) {
      const errors = batchDetails.errors;
      if (errors) {
        return new Response(JSON.stringify(errors), {
          status: 200,
        });
      }
      return new Response(
        JSON.stringify({
          error: 'invalid response output format',
          provider_response: batchDetails,
          provider: AZURE_OPEN_AI,
        }),
        {
          status: 400,
        }
      );
    }
    const retrieveFileContentRequestURL = `https://api.portkey.ai/v1/files/${outputFileId}/content`; // construct the entire URL instead of just the path, for sanity's sake
    const retrieveFileContentURL =
      baseUrl +
      AzureAIInferenceAPI.getEndpoint({
        providerOptions,
        fn: 'retrieveFileContent',
        gatewayRequestURL: retrieveFileContentRequestURL,
        c,
        gatewayRequestBodyJSON: {},
        gatewayRequestBody: {},
      });
    const retrieveFileContentHeaders = await AzureAIInferenceAPI.headers({
      c,
      providerOptions,
      fn: 'retrieveFileContent',
      transformedRequestBody: {},
      transformedRequestUrl: retrieveFileContentURL,
      gatewayRequestBody: {},
    });
    const response = fetch(retrieveFileContentURL, {
      method: 'GET',
      headers: retrieveFileContentHeaders,
    });
    return response;
  } catch (e) {
    return new Response(
      JSON.stringify({
        error: 'error fetching batch output',
        provider: AZURE_OPEN_AI,
        param: null,
      }),
      {
        status: 500,
      }
    );
  }
};
```
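
A hedged usage sketch of the new handler wired into a route. The route path, Foundry URL, and credential are assumptions for illustration, not part of this commit.

```ts
import { Hono } from 'hono';

const app = new Hono();

// Hypothetical route: stream a batch's output file back to the client.
app.get('/v1/batches/:batchId/output', (c) =>
  AzureAIInferenceGetBatchOutputRequestHandler({
    c,
    providerOptions: {
      azureFoundryUrl:
        'https://myresource.services.ai.azure.com/deployments/gpt-4o', // assumed
      apiKey: 'AZURE_API_KEY', // placeholder credential
    } as any, // Options shape assumed; cast for the sketch
    requestURL: c.req.url,
  })
);
```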
