Skip to content

Commit 8fff80d

Browse files
committed
Added models handler to Albus
1 parent 1a1b881 commit 8fff80d

File tree

3 files changed

+48
-30
lines changed

3 files changed

+48
-30
lines changed

src/globals.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ export const POSSIBLE_RETRY_STATUS_HEADERS = [
1111
];
1212

1313
export const HEADER_KEYS: Record<string, string> = {
14+
API_KEY: `x-${POWERED_BY}-api-key`,
1415
MODE: `x-${POWERED_BY}-mode`,
1516
RETRIES: `x-${POWERED_BY}-retry-count`,
1617
PROVIDER: `x-${POWERED_BY}-provider`,
@@ -23,6 +24,7 @@ export const HEADER_KEYS: Record<string, string> = {
2324
REQUEST_TIMEOUT: `x-${POWERED_BY}-request-timeout`,
2425
STRICT_OPEN_AI_COMPLIANCE: `x-${POWERED_BY}-strict-open-ai-compliance`,
2526
CONTENT_TYPE: `Content-Type`,
27+
VIRTUAL_KEY: `x-${POWERED_BY}-virtual-key`,
2628
};
2729

2830
export const RESPONSE_HEADER_KEYS: Record<string, string> = {

src/handlers/modelsHandler.ts

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,53 @@
1-
import { Context } from 'hono';
2-
import models from '../data/models.json';
3-
import providers from '../data/providers.json';
1+
import { Context, Next } from 'hono';
2+
import { HEADER_KEYS } from '../globals';
3+
import { env } from 'hono/adapter';
44

55
/**
66
* Handles the models request. Returns a list of models supported by the Ai gateway.
77
* Allows filters in query params for the provider
88
* @param c - The Hono context
99
* @returns - The response
1010
*/
11-
export async function modelsHandler(c: Context): Promise<Response> {
12-
// If the request does not contain a provider query param, return all models. Add a count as well.
13-
const provider = c.req.query('provider');
14-
if (!provider) {
15-
return c.json({
16-
...models,
17-
count: models.data.length,
18-
});
19-
} else {
20-
// Filter the models by the provider
21-
const filteredModels = models.data.filter(
22-
(model: any) => model.provider.id === provider
23-
);
24-
return c.json({
25-
...models,
26-
data: filteredModels,
27-
count: filteredModels.length,
28-
});
11+
export const modelsHandler = async (context: Context, next: Next) => {
12+
const fetchOptions: Record<string, any> = {};
13+
fetchOptions['method'] = context.req.method;
14+
15+
const headers = Object.fromEntries(context.req.raw.headers);
16+
17+
const authHeader = headers['Authorization'] || headers['authorization'];
18+
19+
const apiKey =
20+
headers[HEADER_KEYS.API_KEY] || authHeader?.replace('Bearer ', '');
21+
let config: any = headers[HEADER_KEYS.CONFIG];
22+
if (config && typeof config === 'string') {
23+
try {
24+
config = JSON.parse(config);
25+
} catch {
26+
config = {};
27+
}
2928
}
30-
}
29+
const providerHeader = headers[HEADER_KEYS.PROVIDER];
30+
const virtualKey = headers[HEADER_KEYS.VIRTUAL_KEY];
31+
32+
const containsProvider =
33+
providerHeader || virtualKey || config?.provider || config?.virtual_key;
34+
35+
if (containsProvider) {
36+
return next();
37+
}
38+
39+
// Strip gateway endpoint for models endpoint.
40+
const urlObject = new URL(context.req.url);
41+
const requestRoute = `${env(context).ALBUS_BASEPATH}${context.req.path.replace('/v1/', '/v2/')}${urlObject.search}`;
42+
fetchOptions['headers'] = {
43+
[HEADER_KEYS.API_KEY]: apiKey,
44+
};
3145

32-
export async function providersHandler(c: Context): Promise<Response> {
33-
return c.json({
34-
...providers,
35-
count: providers.data.length,
46+
const resp = await fetch(requestRoute, fetchOptions);
47+
return new Response(resp.body, {
48+
status: resp.status,
49+
headers: {
50+
'content-type': 'application/json',
51+
},
3652
});
37-
}
53+
};

src/index.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ if (getRuntimeKey() === 'node') {
9191
app.use(logger());
9292
}
9393

94+
// Support the /v1/models endpoint
95+
app.get('/v1/models', modelsHandler);
96+
9497
// Use hooks middleware for all routes
9598
app.use('*', hooks);
9699

@@ -252,9 +255,6 @@ app.post('/v1/prompts/*', requestValidator, (c) => {
252255
});
253256
});
254257

255-
app.get('/v1/reference/models', modelsHandler);
256-
app.get('/v1/reference/providers', providersHandler);
257-
258258
// WebSocket route
259259
if (runtime === 'workerd') {
260260
app.get('/v1/realtime', realTimeHandler);

0 commit comments

Comments
 (0)