From ef99ad6df456053302a7752448dfbe90626cfd7f Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Mon, 22 Sep 2025 19:07:35 +0530 Subject: [PATCH 01/25] Rate limits and budgets fix errors --- initializeSettings.ts | 66 +++ package-lock.json | 92 +++- package.json | 1 + settings.example.json | 32 ++ src/globals.ts | 13 + src/handlers/services/responseService.ts | 10 +- src/index.ts | 23 +- src/services/realtimeLlmEventParser.ts | 13 +- .../services/cache/backends/cloudflareKV.ts | 228 ++++++++ src/shared/services/cache/backends/file.ts | 321 ++++++++++++ src/shared/services/cache/backends/memory.ts | 220 ++++++++ src/shared/services/cache/backends/redis.ts | 252 +++++++++ src/shared/services/cache/index.ts | 486 ++++++++++++++++++ src/shared/services/cache/types.ts | 57 ++ .../services/cache/utils/rateLimiter.ts | 182 +++++++ src/shared/utils/logger.ts | 128 +++++ src/utils/misc.ts | 13 + wrangler.toml | 12 + 18 files changed, 2135 insertions(+), 14 deletions(-) create mode 100644 initializeSettings.ts create mode 100644 settings.example.json create mode 100644 src/shared/services/cache/backends/cloudflareKV.ts create mode 100644 src/shared/services/cache/backends/file.ts create mode 100644 src/shared/services/cache/backends/memory.ts create mode 100644 src/shared/services/cache/backends/redis.ts create mode 100644 src/shared/services/cache/index.ts create mode 100644 src/shared/services/cache/types.ts create mode 100644 src/shared/services/cache/utils/rateLimiter.ts create mode 100644 src/shared/utils/logger.ts diff --git a/initializeSettings.ts b/initializeSettings.ts new file mode 100644 index 000000000..2ee0c6727 --- /dev/null +++ b/initializeSettings.ts @@ -0,0 +1,66 @@ +const organisationDetails = { + id: '00000000-0000-0000-0000-000000000000', + name: 'Portkey self hosted', + settings: { + debug_log: 1, + is_virtual_key_limit_enabled: 1, + allowed_guardrails: ['BASIC'], + }, + workspaceDetails: {}, + defaults: { + metadata: null, + }, + usageLimits: 
[], + rateLimits: [], + organisationDefaults: { + input_guardrails: null, + }, +}; + +const transformIntegrations = (integrations: any) => { + return integrations.map((integration: any) => { + return { + id: '1234567890', //need to do consistent hashing for caching + ai_provider_name: integration.provider, + model_config: { + ...integration.credentials, + }, + ...(integration.credentials?.apiKey && { + key: integration.credentials.apiKey, + }), + slug: integration.slug, + usage_limits: null, + status: 'active', + integration_id: '1234567890', + object: 'virtual-key', + integration_details: { + id: '1234567890', + slug: integration.slug, + usage_limits: integration.usage_limits, + rate_limits: integration.rate_limits, + models: integration.models, + }, + }; + }); +}; + +let settings: any = {}; +try { + // @ts-expect-error + const settingsFile = await import('./settings.json'); + if (!settingsFile) { + settings = undefined; + } else { + settings.organisationDetails = organisationDetails; + if (settingsFile.integrations) { + settings.integrations = transformIntegrations(settingsFile.integrations); + } + } +} catch (error) { + console.log( + 'WARNING: Unable to import settings from the path, please make sure the file exists', + error + ); +} + +export { settings }; diff --git a/package-lock.json b/package-lock.json index 01750ed44..383daf4f2 100644 --- a/package-lock.json +++ b/package-lock.json @@ -20,6 +20,7 @@ "async-retry": "^1.3.3", "avsc": "^5.7.7", "hono": "^4.6.10", + "ioredis": "^5.8.0", "jose": "^6.0.11", "patch-package": "^8.0.0", "ws": "^8.18.0", @@ -1412,6 +1413,12 @@ "url": "https://github.com/sponsors/nzakas" } }, + "node_modules/@ioredis/commands": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@ioredis/commands/-/commands-1.4.0.tgz", + "integrity": "sha512-aFT2yemJJo+TZCmieA7qnYGQooOS7QfNmYrzGtsYd3g9j5iDP8AimYYAesf79ohjbLG12XxC4nG5DyEnC88AsQ==", + "license": "MIT" + }, "node_modules/@istanbuljs/load-nyc-config": { "version": "1.1.0", 
"resolved": "https://registry.npmjs.org/@istanbuljs/load-nyc-config/-/load-nyc-config-1.1.0.tgz", @@ -3239,6 +3246,15 @@ "node": ">=12" } }, + "node_modules/cluster-key-slot": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/cluster-key-slot/-/cluster-key-slot-1.1.2.tgz", + "integrity": "sha512-RMr0FhtfXemyinomL4hrWcYJxmX6deFdCxpJzhDttxgO1+bcCnkk+9drydLVDmAMG7NE6aN/fl4F7ucU/90gAA==", + "license": "Apache-2.0", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/co": { "version": "4.6.0", "resolved": "https://registry.npmjs.org/co/-/co-4.6.0.tgz", @@ -3373,7 +3389,6 @@ "version": "4.3.4", "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", - "dev": true, "dependencies": { "ms": "2.1.2" }, @@ -3449,6 +3464,15 @@ "node": ">=0.4.0" } }, + "node_modules/denque": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/denque/-/denque-2.1.0.tgz", + "integrity": "sha512-HVQE3AAb/pxF8fQAoiqpvg9i3evqug3hoiwakOyZAwJm+6vZehbkYXZ0l4JxS+I3QxM97v5aaRNhj8v5oBhekw==", + "license": "Apache-2.0", + "engines": { + "node": ">=0.10" + } + }, "node_modules/detect-newline": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/detect-newline/-/detect-newline-3.1.0.tgz", @@ -4614,6 +4638,30 @@ "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==" }, + "node_modules/ioredis": { + "version": "5.8.0", + "resolved": "https://registry.npmjs.org/ioredis/-/ioredis-5.8.0.tgz", + "integrity": "sha512-AUXbKn9gvo9hHKvk6LbZJQSKn/qIfkWXrnsyL9Yrf+oeXmla9Nmf6XEumOddyhM8neynpK5oAV6r9r99KBuwzA==", + "license": "MIT", + "dependencies": { + "@ioredis/commands": "1.4.0", + "cluster-key-slot": "^1.1.0", + "debug": "^4.3.4", + "denque": "^2.1.0", + "lodash.defaults": "^4.2.0", + "lodash.isarguments": "^3.1.0", + 
"redis-errors": "^1.2.0", + "redis-parser": "^3.0.0", + "standard-as-callback": "^2.1.0" + }, + "engines": { + "node": ">=12.22.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/ioredis" + } + }, "node_modules/is-arrayish": { "version": "0.2.1", "resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.2.1.tgz", @@ -5607,6 +5655,18 @@ "node": ">=8" } }, + "node_modules/lodash.defaults": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/lodash.defaults/-/lodash.defaults-4.2.0.tgz", + "integrity": "sha512-qjxPLHd3r5DnsdGacqOMU6pb/avJzdh9tFX2ymgoZE27BmjXrNy/y4LoaiTeAb+O3gL8AfpJGtqfX/ae2leYYQ==", + "license": "MIT" + }, + "node_modules/lodash.isarguments": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/lodash.isarguments/-/lodash.isarguments-3.1.0.tgz", + "integrity": "sha512-chi4NHZlZqZD18a0imDHnZPrDeBbTtVN7GXMwuGdRH9qotxAjYs3aVLKc7zNOG9eddR5Ksd8rvFEBc9SsggPpg==", + "license": "MIT" + }, "node_modules/lodash.memoize": { "version": "4.1.2", "resolved": "https://registry.npmjs.org/lodash.memoize/-/lodash.memoize-4.1.2.tgz", @@ -5822,8 +5882,7 @@ "node_modules/ms": { "version": "2.1.2", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", - "dev": true + "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==" }, "node_modules/mustache": { "version": "4.2.0", @@ -6482,6 +6541,27 @@ "url": "https://paulmillr.com/funding/" } }, + "node_modules/redis-errors": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/redis-errors/-/redis-errors-1.2.0.tgz", + "integrity": "sha512-1qny3OExCf0UvUV/5wpYKf2YwPcOqXzkwKKSmKHiE6ZMQs5heeE/c8eXK+PNllPvmjgAbfnsbpkGZWy8cBpn9w==", + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/redis-parser": { + "version": "3.0.0", + "resolved": 
"https://registry.npmjs.org/redis-parser/-/redis-parser-3.0.0.tgz", + "integrity": "sha512-DJnGAeenTdpMEH6uAJRK/uiyEIH9WVsUmoLwzudwGJUwZPp80PDBWPHXSAGNPwNvIXAbe7MSUB1zQFugFml66A==", + "license": "MIT", + "dependencies": { + "redis-errors": "^1.0.0" + }, + "engines": { + "node": ">=4" + } + }, "node_modules/require-directory": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/require-directory/-/require-directory-2.1.1.tgz", @@ -6883,6 +6963,12 @@ "get-source": "^2.0.12" } }, + "node_modules/standard-as-callback": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/standard-as-callback/-/standard-as-callback-2.1.0.tgz", + "integrity": "sha512-qoRRSyROncaz1z0mvYqIE4lCd9p2R90i6GxW3uZv5ucSu8tU7B5HXUP1gG8pVZsYNVaXjk8ClXHPttLyxAL48A==", + "license": "MIT" + }, "node_modules/stoppable": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/stoppable/-/stoppable-1.1.0.tgz", diff --git a/package.json b/package.json index 0fea243ed..982e19982 100644 --- a/package.json +++ b/package.json @@ -51,6 +51,7 @@ "async-retry": "^1.3.3", "avsc": "^5.7.7", "hono": "^4.6.10", + "ioredis": "^5.8.0", "jose": "^6.0.11", "patch-package": "^8.0.0", "ws": "^8.18.0", diff --git a/settings.example.json b/settings.example.json new file mode 100644 index 000000000..5d9790d42 --- /dev/null +++ b/settings.example.json @@ -0,0 +1,32 @@ +{ + "integrations": [ + { + "provider": "anthropic", + "slug": "dev_team_anthropic", + "credentials": { + "apiKey": "sk-ant-" + }, + "rate_limits": [ + { + "type": "requests", + "unit": "rph", + "value": 3 + } + ], + "usage_limits": [ + { + "type": "tokens", + "credit_limit": 1000000, + "periodic_reset": "weekly" + } + ], + "models": [ + { + "slug": "claude-3-7-sonnet-20250219", + "status": "active", + "pricing_config": null + } + ] + } + ] +} diff --git a/src/globals.ts b/src/globals.ts index 8ec98f4fb..af88e3861 100644 --- a/src/globals.ts +++ b/src/globals.ts @@ -243,3 +243,16 @@ export enum BatchEndpoints { COMPLETIONS = 
'/v1/completions', EMBEDDINGS = '/v1/embeddings', } + +export const AtomicOperations = { + GET: 'GET', + RESET: 'RESET', + INCREMENT: 'INCREMENT', + DECREMENT: 'DECREMENT', +}; + +export enum RateLimiterKeyTypes { + VIRTUAL_KEY = 'VIRTUAL_KEY', + API_KEY = 'API_KEY', + INTEGRATION_WORKSPACE = 'INTEGRATION_WORKSPACE', +} diff --git a/src/handlers/services/responseService.ts b/src/handlers/services/responseService.ts index 5c35e55d3..21146a9c8 100644 --- a/src/handlers/services/responseService.ts +++ b/src/handlers/services/responseService.ts @@ -1,5 +1,6 @@ // responseService.ts +import { getRuntimeKey } from 'hono/adapter'; import { HEADER_KEYS, POWERED_BY, RESPONSE_HEADER_KEYS } from '../../globals'; import { responseHandler } from '../responseHandlers'; import { HooksService } from './hooksService'; @@ -121,10 +122,11 @@ export class ResponseService { } // Remove headers directly - // const encoding = response.headers.get('content-encoding'); - // if (encoding?.includes('br') || getRuntimeKey() == 'node') { - // response.headers.delete('content-encoding'); - // } + // TODO: verify a workaround for node environments with brotli encoding + const encoding = response.headers.get('content-encoding'); + if (encoding?.includes('br') || getRuntimeKey() === 'node') { + response.headers.delete('content-encoding'); + } response.headers.delete('content-length'); // response.headers.delete('transfer-encoding'); diff --git a/src/index.ts b/src/index.ts index 594722d51..c18e33571 100644 --- a/src/index.ts +++ b/src/index.ts @@ -8,7 +8,7 @@ import { Context, Hono } from 'hono'; import { prettyJSON } from 'hono/pretty-json'; import { HTTPException } from 'hono/http-exception'; import { compress } from 'hono/compress'; -import { getRuntimeKey } from 'hono/adapter'; +import { env, getRuntimeKey } from 'hono/adapter'; // import { env } from 'hono/adapter' // Have to set this up for multi-environment deployment // Middlewares @@ -36,16 +36,33 @@ import { messagesHandler } from 
'./handlers/messagesHandler'; // Config import conf from '../conf.json'; import modelResponsesHandler from './handlers/modelResponsesHandler'; +import { + createCacheBackendsLocal, + createCacheBackendsRedis, + createCacheBackendsCF, +} from './shared/services/cache'; // Create a new Hono server instance const app = new Hono(); +const runtime = getRuntimeKey(); + +// cache beackends will only get created during worker or app initialization depending on the runtime +if (getRuntimeKey() === 'workerd') { + app.use('*', (c: Context, next) => { + createCacheBackendsCF(env(c)); + return next(); + }); +} else if (getRuntimeKey() === 'node' && process.env.REDIS_CONNECTION_STRING) { + createCacheBackendsRedis(process.env.REDIS_CONNECTION_STRING); +} else { + createCacheBackendsLocal(); +} + /** * Middleware that conditionally applies compression middleware based on the runtime. * Compression is automatically handled for lagon and workerd runtimes * This check if its not any of the 2 and then applies the compress middleware to avoid double compression. 
*/ - -const runtime = getRuntimeKey(); app.use('*', (c, next) => { const runtimesThatDontNeedCompression = ['lagon', 'workerd', 'node']; if (runtimesThatDontNeedCompression.includes(runtime)) { diff --git a/src/services/realtimeLlmEventParser.ts b/src/services/realtimeLlmEventParser.ts index 88415cc87..12432c2ca 100644 --- a/src/services/realtimeLlmEventParser.ts +++ b/src/services/realtimeLlmEventParser.ts @@ -1,4 +1,5 @@ import { Context } from 'hono'; +import { addBackgroundTask } from '../utils/misc'; export class RealtimeLlmEventParser { private sessionState: any; @@ -48,7 +49,8 @@ export class RealtimeLlmEventParser { this.sessionState.sessionDetails = { ...data.session }; const realtimeEventParser = c.get('realtimeEventParser'); if (realtimeEventParser) { - c.executionCtx.waitUntil( + addBackgroundTask( + c, realtimeEventParser( c, sessionOptions, @@ -69,7 +71,8 @@ export class RealtimeLlmEventParser { this.sessionState.sessionDetails = { ...data.session }; const realtimeEventParser = c.get('realtimeEventParser'); if (realtimeEventParser) { - c.executionCtx.waitUntil( + addBackgroundTask( + c, realtimeEventParser( c, sessionOptions, @@ -106,7 +109,8 @@ export class RealtimeLlmEventParser { const itemSequence = this.rebuildConversationSequence( this.sessionState.conversation.items ); - c.executionCtx.waitUntil( + addBackgroundTask( + c, realtimeEventParser( c, sessionOptions, @@ -128,7 +132,8 @@ export class RealtimeLlmEventParser { private handleError(c: Context, data: any, sessionOptions: any): void { const realtimeEventParser = c.get('realtimeEventParser'); if (realtimeEventParser) { - c.executionCtx.waitUntil( + addBackgroundTask( + c, realtimeEventParser(c, sessionOptions, {}, data, data.type) ); } diff --git a/src/shared/services/cache/backends/cloudflareKV.ts b/src/shared/services/cache/backends/cloudflareKV.ts new file mode 100644 index 000000000..820d461e9 --- /dev/null +++ b/src/shared/services/cache/backends/cloudflareKV.ts @@ -0,0 +1,228 @@ +/** + 
* @file src/services/cache/backends/cloudflareKV.ts + * Cloudflare KV cache backend implementation + */ + +import { CacheBackend, CacheEntry, CacheOptions, CacheStats } from '../types'; + +// Using console.log for now to avoid build issues +const logger = { + debug: (msg: string, ...args: any[]) => + console.debug(`[CloudflareKVCache] ${msg}`, ...args), + info: (msg: string, ...args: any[]) => + console.info(`[CloudflareKVCache] ${msg}`, ...args), + warn: (msg: string, ...args: any[]) => + console.warn(`[CloudflareKVCache] ${msg}`, ...args), + error: (msg: string, ...args: any[]) => + console.error(`[CloudflareKVCache] ${msg}`, ...args), +}; + +// Cloudflare KV client interface +interface ICloudflareKVClient { + get(key: string): Promise; + set(key: string, value: string, options?: CacheOptions): Promise; + del(key: string): Promise; + keys(prefix: string): Promise; +} + +export class CloudflareKVCacheBackend implements CacheBackend { + private client: ICloudflareKVClient; + private dbName: string; + + private stats: CacheStats = { + hits: 0, + misses: 0, + sets: 0, + deletes: 0, + size: 0, + expired: 0, + }; + + constructor(client: ICloudflareKVClient, dbName: string) { + this.client = client; + this.dbName = dbName; + } + + private getFullKey(key: string, namespace?: string): string { + return namespace + ? 
`${this.dbName}:${namespace}:${key}` + : `${this.dbName}:default:${key}`; + } + + private serializeEntry(entry: CacheEntry): string { + return JSON.stringify(entry); + } + + private deserializeEntry(data: string): CacheEntry { + return JSON.parse(data); + } + + async get( + key: string, + namespace?: string + ): Promise | null> { + try { + const fullKey = this.getFullKey(key, namespace); + const data = await this.client.get(fullKey); + + if (!data) { + this.stats.misses++; + return null; + } + + const entry = this.deserializeEntry(data); + + this.stats.hits++; + return entry; + } catch (error) { + logger.error('Redis get error:', error); + this.stats.misses++; + return null; + } + } + + async set( + key: string, + value: T, + options: CacheOptions = {} + ): Promise { + try { + const fullKey = this.getFullKey(key, options.namespace); + const now = Date.now(); + + const entry: CacheEntry = { + value, + createdAt: now, + expiresAt: options.ttl ? now + options.ttl : undefined, + metadata: options.metadata, + }; + + const serialized = this.serializeEntry(entry); + + this.client.set(fullKey, serialized, options); + + this.stats.sets++; + } catch (error) { + logger.error('Cloudflare KV set error:', error); + throw error; + } + } + + async delete(key: string, namespace?: string): Promise { + try { + const fullKey = this.getFullKey(key, namespace); + const deleted = await this.client.del(fullKey); + + if (deleted > 0) { + this.stats.deletes++; + return true; + } + + return false; + } catch (error) { + logger.error('Cloudflare KV delete error:', error); + return false; + } + } + + async clear(namespace?: string): Promise { + logger.debug('Cloudflare KV clear not implemented', namespace); + } + + async keys(namespace?: string): Promise { + try { + const prefix = namespace ? 
`cache:${namespace}:` : 'cache:default:'; + const fullKeys = await this.client.keys(prefix); + + return fullKeys.map((key) => key.substring(prefix.length)); + } catch (error) { + logger.error('Redis keys error:', error); + return []; + } + } + + async getStats(namespace?: string): Promise { + try { + const prefix = namespace ? `cache:${namespace}:` : 'cache:default:'; + const keys = await this.client.keys(prefix); + + return { + ...this.stats, + size: keys.length, + }; + } catch (error) { + logger.error('Redis getStats error:', error); + return { ...this.stats }; + } + } + + async has(key: string, namespace?: string): Promise { + logger.info('Cloudflare KV has not implemented', key, namespace); + return false; + } + + async cleanup(): Promise { + // Redis handles TTL automatically, so this is mostly a no-op + // We could scan for entries with manual expiration and clean them up + logger.debug('Redis cleanup - TTL handled automatically by Redis'); + } + + async close(): Promise { + logger.debug('Cloudflare KV close not implemented'); + } +} + +// Cloudflare KV client implementation +class CloudflareKVClient implements ICloudflareKVClient { + private KV: any; + + constructor(env: any, kvBindingName: string) { + this.KV = env[kvBindingName]; + } + + get = async (key: string): Promise => { + return await this.KV.get(key); + }; + + set = async ( + key: string, + value: string, + options?: CacheOptions + ): Promise => { + const kvOptions = { + expirationTtl: options?.ttl, + metadata: options?.metadata, + }; + try { + await this.KV.put(key, value, kvOptions); + return; + } catch (error) { + logger.error('Error setting key in Cloudflare KV:', error); + throw error; + } + }; + + del = async (key: string): Promise => { + try { + await this.KV.delete(key); + return 1; + } catch (error) { + logger.error('Error deleting key in Cloudflare KV:', error); + throw error; + } + }; + + keys = async (prefix: string): Promise => { + return await this.KV.list({ prefix }); + }; +} + +// 
Factory function to create Cloudflare KV backend +export function createCloudflareKVBackend( + env: any, + bindingName: string, + dbName: string +): CloudflareKVCacheBackend { + const client = new CloudflareKVClient(env, bindingName); + return new CloudflareKVCacheBackend(client, dbName); +} diff --git a/src/shared/services/cache/backends/file.ts b/src/shared/services/cache/backends/file.ts new file mode 100644 index 000000000..e517960ba --- /dev/null +++ b/src/shared/services/cache/backends/file.ts @@ -0,0 +1,321 @@ +/** + * @file src/services/cache/backends/file.ts + * File-based cache backend implementation + */ + +import { CacheBackend, CacheEntry, CacheOptions, CacheStats } from '../types'; +import * as fs from 'fs/promises'; +import * as path from 'path'; + +// Using console.log for now to avoid build issues +const logger = { + debug: (msg: string, ...args: any[]) => + console.debug(`[FileCache] ${msg}`, ...args), + info: (msg: string, ...args: any[]) => + console.info(`[FileCache] ${msg}`, ...args), + warn: (msg: string, ...args: any[]) => + console.warn(`[FileCache] ${msg}`, ...args), + error: (msg: string, ...args: any[]) => + console.error(`[FileCache] ${msg}`, ...args), +}; + +interface FileCacheData { + [namespace: string]: { + [key: string]: CacheEntry; + }; +} + +export class FileCacheBackend implements CacheBackend { + private cacheFile: string; + private data: FileCacheData = {}; + private saveTimer?: NodeJS.Timeout; + private cleanupInterval?: NodeJS.Timeout; + private loaded: boolean = false; + private loadPromise: Promise; + private stats: CacheStats = { + hits: 0, + misses: 0, + sets: 0, + deletes: 0, + size: 0, + expired: 0, + }; + private saveInterval: number; + constructor( + dataDir: string = 'data', + fileName: string = 'cache.json', + saveIntervalMs: number = 1000, + cleanupIntervalMs: number = 60000 + ) { + this.cacheFile = path.join(process.cwd(), dataDir, fileName); + this.saveInterval = saveIntervalMs; + this.loadPromise = 
this.loadCache(); + this.loadPromise.then(() => { + this.startCleanup(cleanupIntervalMs); + }); + } + + // Ensure cache is loaded before any operation + private async ensureLoaded(): Promise { + if (!this.loaded) { + await this.loadPromise; + } + } + + private async ensureDataDir(): Promise { + const dir = path.dirname(this.cacheFile); + try { + await fs.mkdir(dir, { recursive: true }); + } catch (error) { + logger.error('Failed to create cache directory:', error); + } + } + + private async loadCache(): Promise { + try { + const content = await fs.readFile(this.cacheFile, 'utf-8'); + this.data = JSON.parse(content); + this.updateStats(); + logger.debug('Loaded cache from disk', this.cacheFile); + this.loaded = true; + } catch (error) { + // File doesn't exist or is invalid, start with empty cache + this.data = {}; + logger.debug('Starting with empty cache'); + } + } + + private async saveCache(): Promise { + try { + await this.ensureDataDir(); + await fs.writeFile(this.cacheFile, JSON.stringify(this.data, null, 2)); + logger.debug('Saved cache to disk'); + } catch (error) { + logger.error('Failed to save cache:', error); + } + } + + private scheduleSave(): void { + if (this.saveTimer) { + clearTimeout(this.saveTimer); + } + + this.saveTimer = setTimeout(() => { + this.saveCache(); + this.saveTimer = undefined; + }, this.saveInterval); + } + + private startCleanup(intervalMs: number): void { + this.cleanupInterval = setInterval(() => { + this.cleanup(); + }, intervalMs); + } + + private isExpired(entry: CacheEntry): boolean { + return entry.expiresAt !== undefined && entry.expiresAt <= Date.now(); + } + + private updateStats(): void { + let totalSize = 0; + let totalExpired = 0; + + for (const namespace of Object.values(this.data)) { + for (const entry of Object.values(namespace)) { + totalSize++; + if (this.isExpired(entry)) { + totalExpired++; + } + } + } + + this.stats.size = totalSize; + this.stats.expired = totalExpired; + } + + private getNamespaceData( + 
namespace: string = 'default' + ): Record { + if (!this.data[namespace]) { + this.data[namespace] = {}; + } + return this.data[namespace]; + } + + async get( + key: string, + namespace?: string + ): Promise | null> { + await this.ensureLoaded(); // Wait for load to complete + + const namespaceData = this.getNamespaceData(namespace); + const entry = namespaceData[key]; + + if (!entry) { + this.stats.misses++; + return null; + } + + if (this.isExpired(entry)) { + delete namespaceData[key]; + this.stats.expired++; + this.stats.misses++; + this.scheduleSave(); + return null; + } + + this.stats.hits++; + return entry as CacheEntry; + } + + async set( + key: string, + value: T, + options: CacheOptions = {} + ): Promise { + await this.ensureLoaded(); // Wait for load to complete + + const namespace = options.namespace || 'default'; + const namespaceData = this.getNamespaceData(namespace); + const now = Date.now(); + + const entry: CacheEntry = { + value, + createdAt: now, + expiresAt: options.ttl ? 
now + options.ttl : undefined, + metadata: options.metadata, + }; + + namespaceData[key] = entry; + this.stats.sets++; + this.updateStats(); + this.scheduleSave(); + } + + async delete(key: string, namespace?: string): Promise { + const namespaceData = this.getNamespaceData(namespace); + const existed = key in namespaceData; + + if (existed) { + delete namespaceData[key]; + this.stats.deletes++; + this.updateStats(); + this.scheduleSave(); + } + + return existed; + } + + async clear(namespace?: string): Promise { + if (namespace) { + const namespaceData = this.getNamespaceData(namespace); + const count = Object.keys(namespaceData).length; + this.data[namespace] = {}; + this.stats.deletes += count; + } else { + const totalCount = Object.values(this.data).reduce( + (sum, ns) => sum + Object.keys(ns).length, + 0 + ); + this.data = {}; + this.stats.deletes += totalCount; + } + + this.updateStats(); + this.scheduleSave(); + } + + async has(key: string, namespace?: string): Promise { + const namespaceData = this.getNamespaceData(namespace); + const entry = namespaceData[key]; + + if (!entry) return false; + + if (this.isExpired(entry)) { + delete namespaceData[key]; + this.stats.expired++; + this.scheduleSave(); + return false; + } + + return true; + } + + async keys(namespace?: string): Promise { + if (namespace) { + const namespaceData = this.getNamespaceData(namespace); + return Object.keys(namespaceData); + } + + const allKeys: string[] = []; + for (const namespaceData of Object.values(this.data)) { + allKeys.push(...Object.keys(namespaceData)); + } + return allKeys; + } + + async getStats(namespace?: string): Promise { + if (namespace) { + const namespaceData = this.getNamespaceData(namespace); + const keys = Object.keys(namespaceData); + let expired = 0; + + for (const key of keys) { + const entry = namespaceData[key]; + if (this.isExpired(entry)) { + expired++; + } + } + + return { + ...this.stats, + size: keys.length, + expired, + }; + } + + this.updateStats(); + 
return { ...this.stats }; + } + + async cleanup(): Promise { + let expiredCount = 0; + let hasChanges = false; + + for (const [, namespaceData] of Object.entries(this.data)) { + for (const [key, entry] of Object.entries(namespaceData)) { + if (this.isExpired(entry)) { + delete namespaceData[key]; + expiredCount++; + hasChanges = true; + } + } + } + + if (hasChanges) { + this.stats.expired += expiredCount; + this.updateStats(); + this.scheduleSave(); + logger.debug(`Cleaned up ${expiredCount} expired entries`); + } + } + + // Add method to check if ready + async waitForReady(): Promise { + await this.loadPromise; + } + + async close(): Promise { + if (this.saveTimer) { + clearTimeout(this.saveTimer); + await this.saveCache(); // Final save + } + + if (this.cleanupInterval) { + clearInterval(this.cleanupInterval); + this.cleanupInterval = undefined; + } + + logger.debug('File cache backend closed'); + } +} diff --git a/src/shared/services/cache/backends/memory.ts b/src/shared/services/cache/backends/memory.ts new file mode 100644 index 000000000..f1e225da4 --- /dev/null +++ b/src/shared/services/cache/backends/memory.ts @@ -0,0 +1,220 @@ +/** + * @file src/services/cache/backends/memory.ts + * In-memory cache backend implementation + */ + +import { CacheBackend, CacheEntry, CacheOptions, CacheStats } from '../types'; +// Using console.log for now to avoid build issues +const logger = { + debug: (msg: string, ...args: any[]) => + console.debug(`[MemoryCache] ${msg}`, ...args), + info: (msg: string, ...args: any[]) => + console.info(`[MemoryCache] ${msg}`, ...args), + warn: (msg: string, ...args: any[]) => + console.warn(`[MemoryCache] ${msg}`, ...args), + error: (msg: string, ...args: any[]) => + console.error(`[MemoryCache] ${msg}`, ...args), +}; + +export class MemoryCacheBackend implements CacheBackend { + private cache = new Map(); + private stats: CacheStats = { + hits: 0, + misses: 0, + sets: 0, + deletes: 0, + size: 0, + expired: 0, + }; + private 
cleanupInterval?: NodeJS.Timeout; + private maxSize: number; + + constructor(maxSize: number = 10000, cleanupIntervalMs: number = 60000) { + this.maxSize = maxSize; + this.startCleanup(cleanupIntervalMs); + } + + private startCleanup(intervalMs: number): void { + this.cleanupInterval = setInterval(() => { + this.cleanup(); + }, intervalMs); + } + + private getFullKey(key: string, namespace?: string): string { + return namespace ? `${namespace}:${key}` : key; + } + + private isExpired(entry: CacheEntry): boolean { + return entry.expiresAt !== undefined && entry.expiresAt <= Date.now(); + } + + private evictIfNeeded(): void { + if (this.cache.size >= this.maxSize) { + // Simple LRU: remove oldest entries + const entries = Array.from(this.cache.entries()); + entries.sort((a, b) => a[1].createdAt - b[1].createdAt); + + const toRemove = Math.floor(this.maxSize * 0.1); // Remove 10% + for (let i = 0; i < toRemove && i < entries.length; i++) { + this.cache.delete(entries[i][0]); + } + + logger.debug(`Evicted ${toRemove} entries due to size limit`); + } + } + + async get( + key: string, + namespace?: string + ): Promise | null> { + const fullKey = this.getFullKey(key, namespace); + const entry = this.cache.get(fullKey); + + if (!entry) { + this.stats.misses++; + return null; + } + + if (this.isExpired(entry)) { + this.cache.delete(fullKey); + this.stats.expired++; + this.stats.misses++; + return null; + } + + this.stats.hits++; + return entry as CacheEntry; + } + + async set( + key: string, + value: T, + options: CacheOptions = {} + ): Promise { + const fullKey = this.getFullKey(key, options.namespace); + const now = Date.now(); + + const entry: CacheEntry = { + value, + createdAt: now, + expiresAt: options.ttl ? 
now + options.ttl : undefined, + metadata: options.metadata, + }; + + this.evictIfNeeded(); + this.cache.set(fullKey, entry); + this.stats.sets++; + this.stats.size = this.cache.size; + } + + async delete(key: string, namespace?: string): Promise { + const fullKey = this.getFullKey(key, namespace); + const deleted = this.cache.delete(fullKey); + + if (deleted) { + this.stats.deletes++; + this.stats.size = this.cache.size; + } + + return deleted; + } + + async clear(namespace?: string): Promise { + if (namespace) { + const prefix = `${namespace}:`; + const keysToDelete = Array.from(this.cache.keys()).filter((key) => + key.startsWith(prefix) + ); + + for (const key of keysToDelete) { + this.cache.delete(key); + } + + this.stats.deletes += keysToDelete.length; + } else { + this.stats.deletes += this.cache.size; + this.cache.clear(); + } + + this.stats.size = this.cache.size; + } + + async has(key: string, namespace?: string): Promise { + const fullKey = this.getFullKey(key, namespace); + const entry = this.cache.get(fullKey); + + if (!entry) return false; + + if (this.isExpired(entry)) { + this.cache.delete(fullKey); + this.stats.expired++; + return false; + } + + return true; + } + + async keys(namespace?: string): Promise { + const allKeys = Array.from(this.cache.keys()); + + if (namespace) { + const prefix = `${namespace}:`; + return allKeys + .filter((key) => key.startsWith(prefix)) + .map((key) => key.substring(prefix.length)); + } + + return allKeys; + } + + async getStats(namespace?: string): Promise { + if (namespace) { + const prefix = `${namespace}:`; + const namespaceKeys = Array.from(this.cache.keys()).filter((key) => + key.startsWith(prefix) + ); + + let expired = 0; + for (const key of namespaceKeys) { + const entry = this.cache.get(key); + if (entry && this.isExpired(entry)) { + expired++; + } + } + + return { + ...this.stats, + size: namespaceKeys.length, + expired, + }; + } + + return { ...this.stats }; + } + + async cleanup(): Promise { + let 
expiredCount = 0; + + for (const [key, entry] of this.cache.entries()) { + if (this.isExpired(entry)) { + this.cache.delete(key); + expiredCount++; + } + } + + if (expiredCount > 0) { + this.stats.expired += expiredCount; + this.stats.size = this.cache.size; + logger.debug(`Cleaned up ${expiredCount} expired entries`); + } + } + + async close(): Promise { + if (this.cleanupInterval) { + clearInterval(this.cleanupInterval); + this.cleanupInterval = undefined; + } + this.cache.clear(); + logger.debug('Memory cache backend closed'); + } +} diff --git a/src/shared/services/cache/backends/redis.ts b/src/shared/services/cache/backends/redis.ts new file mode 100644 index 000000000..732104b71 --- /dev/null +++ b/src/shared/services/cache/backends/redis.ts @@ -0,0 +1,252 @@ +/** + * @file src/services/cache/backends/redis.ts + * Redis cache backend implementation + */ +import Redis from 'ioredis'; + +import { CacheBackend, CacheEntry, CacheOptions, CacheStats } from '../types'; + +// Using console.log for now to avoid build issues +const logger = { + debug: (msg: string, ...args: any[]) => + console.debug(`[RedisCache] ${msg}`, ...args), + info: (msg: string, ...args: any[]) => + console.info(`[RedisCache] ${msg}`, ...args), + warn: (msg: string, ...args: any[]) => + console.warn(`[RedisCache] ${msg}`, ...args), + error: (msg: string, ...args: any[]) => + console.error(`[RedisCache] ${msg}`, ...args), +}; + +// Redis client interface matching ioredis +interface RedisClient { + get(key: string): Promise; + set( + key: string, + value: string, + expiryMode?: string | any, + time?: number | string + ): Promise<'OK' | null>; + del(...keys: string[]): Promise; + exists(...keys: string[]): Promise; + keys(pattern: string): Promise; + flushdb(): Promise<'OK'>; + quit(): Promise<'OK'>; +} + +export class RedisCacheBackend implements CacheBackend { + private client: RedisClient; + private dbName: string; + + private stats: CacheStats = { + hits: 0, + misses: 0, + sets: 0, + deletes: 
0, + size: 0, + expired: 0, + }; + + constructor(client: RedisClient, dbName: string) { + this.client = client; + this.dbName = dbName; + } + + private getFullKey(key: string, namespace?: string): string { + return namespace + ? `${this.dbName}:${namespace}:${key}` + : `${this.dbName}:default:${key}`; + } + + private serializeEntry(entry: CacheEntry): string { + return JSON.stringify(entry); + } + + private deserializeEntry(data: string): CacheEntry { + return JSON.parse(data); + } + + private isExpired(entry: CacheEntry): boolean { + return entry.expiresAt !== undefined && entry.expiresAt <= Date.now(); + } + + async get( + key: string, + namespace?: string + ): Promise | null> { + try { + const fullKey = this.getFullKey(key, namespace); + const data = await this.client.get(fullKey); + + if (!data) { + this.stats.misses++; + return null; + } + + const entry = this.deserializeEntry(data); + + // Double-check expiration (Redis TTL should handle this, but just in case) + if (this.isExpired(entry)) { + await this.client.del(fullKey); + this.stats.expired++; + this.stats.misses++; + return null; + } + + this.stats.hits++; + return entry; + } catch (error) { + logger.error('Redis get error:', error); + this.stats.misses++; + return null; + } + } + + async set( + key: string, + value: T, + options: CacheOptions = {} + ): Promise { + try { + const fullKey = this.getFullKey(key, options.namespace); + const now = Date.now(); + + const entry: CacheEntry = { + value, + createdAt: now, + expiresAt: options.ttl ? 
now + options.ttl : undefined, + metadata: options.metadata, + }; + + const serialized = this.serializeEntry(entry); + + if (options.ttl) { + // Set with TTL in seconds + const ttlSeconds = Math.ceil(options.ttl / 1000); + await this.client.set(fullKey, serialized, 'EX', ttlSeconds); + } else { + await this.client.set(fullKey, serialized); + } + + this.stats.sets++; + } catch (error) { + logger.error('Redis set error:', error); + throw error; + } + } + + async delete(key: string, namespace?: string): Promise { + try { + const fullKey = this.getFullKey(key, namespace); + const deleted = await this.client.del(fullKey); + + if (deleted > 0) { + this.stats.deletes++; + return true; + } + + return false; + } catch (error) { + logger.error('Redis delete error:', error); + return false; + } + } + + async clear(namespace?: string): Promise { + try { + const pattern = namespace + ? `${this.dbName}:${namespace}:*` + : `${this.dbName}:*`; + const keys = await this.client.keys(pattern); + + if (keys.length > 0) { + // Use single del call with spread operator for better performance + await this.client.del(...keys); + this.stats.deletes += keys.length; + } + } catch (error) { + logger.error('Redis clear error:', error); + throw error; + } + } + + async has(key: string, namespace?: string): Promise { + try { + const fullKey = this.getFullKey(key, namespace); + const exists = await this.client.exists(fullKey); + return exists > 0; + } catch (error) { + logger.error('Redis has error:', error); + return false; + } + } + + async keys(namespace?: string): Promise { + try { + const pattern = namespace + ? `${this.dbName}:${namespace}:*` + : `${this.dbName}:default:*`; + const fullKeys = await this.client.keys(pattern); + + // Extract the actual key part (remove the prefix) + const prefix = namespace + ? 
`${this.dbName}:${namespace}:` + : `${this.dbName}:default:`; + return fullKeys.map((key) => key.substring(prefix.length)); + } catch (error) { + logger.error('Redis keys error:', error); + return []; + } + } + + async getStats(namespace?: string): Promise { + try { + const pattern = namespace + ? `${this.dbName}:${namespace}:*` + : `${this.dbName}:*`; + const keys = await this.client.keys(pattern); + + return { + ...this.stats, + size: keys.length, + }; + } catch (error) { + logger.error('Redis getStats error:', error); + return { ...this.stats }; + } + } + + async cleanup(): Promise { + // Redis handles TTL automatically, so this is mostly a no-op + // We could scan for entries with manual expiration and clean them up + logger.debug('Redis cleanup - TTL handled automatically by Redis'); + } + + async close(): Promise { + try { + await this.client.quit(); + logger.debug('Redis cache backend closed'); + } catch (error) { + logger.error('Error closing Redis connection:', error); + } + } +} + +// Factory function to create Redis backend with ioredis +export function createRedisBackend( + redisUrl: string, + options?: any +): RedisCacheBackend { + // Extract dbName from options or use 'cache' as default + const dbName = options?.dbName || 'cache'; + + // Create ioredis client with URL and any additional options + // ioredis supports Redis URL format: redis://[username:password@]host[:port][/db] + const client = new Redis(redisUrl, { + ...options, + // Remove dbName from options as it's not an ioredis option + dbName: undefined, + }); + + return new RedisCacheBackend(client as RedisClient, dbName); +} diff --git a/src/shared/services/cache/index.ts b/src/shared/services/cache/index.ts new file mode 100644 index 000000000..a229d2b8b --- /dev/null +++ b/src/shared/services/cache/index.ts @@ -0,0 +1,486 @@ +/** + * @file src/services/cache/index.ts + * Unified cache service with pluggable backends + */ + +import { + CacheBackend, + CacheEntry, + CacheOptions, + 
CacheStats, + CacheConfig, +} from './types'; +import { MemoryCacheBackend } from './backends/memory'; +import { FileCacheBackend } from './backends/file'; +import { createRedisBackend } from './backends/redis'; +import { createCloudflareKVBackend } from './backends/cloudflareKV'; +// Using console.log for now to avoid build issues +const logger = { + debug: (msg: string, ...args: any[]) => + console.debug(`[CacheService] ${msg}`, ...args), + info: (msg: string, ...args: any[]) => + console.info(`[CacheService] ${msg}`, ...args), + warn: (msg: string, ...args: any[]) => + console.warn(`[CacheService] ${msg}`, ...args), + error: (msg: string, ...args: any[]) => + console.error(`[CacheService] ${msg}`, ...args), +}; + +const MS = { + '1_MINUTE': 1 * 60 * 1000, + '5_MINUTES': 5 * 60 * 1000, + '10_MINUTES': 10 * 60 * 1000, + '30_MINUTES': 30 * 60 * 1000, + '1_HOUR': 60 * 60 * 1000, + '6_HOURS': 6 * 60 * 60 * 1000, + '12_HOURS': 12 * 60 * 60 * 1000, + '1_DAY': 24 * 60 * 60 * 1000, + '7_DAYS': 7 * 24 * 60 * 60 * 1000, + '30_DAYS': 30 * 24 * 60 * 60 * 1000, +}; + +export class CacheService { + private backend: CacheBackend; + private defaultTtl?: number; + + constructor(config: CacheConfig) { + this.defaultTtl = config.defaultTtl; + this.backend = this.createBackend(config); + } + + private createBackend(config: CacheConfig): CacheBackend { + switch (config.backend) { + case 'memory': + return new MemoryCacheBackend(config.maxSize, config.cleanupInterval); + + case 'file': + return new FileCacheBackend( + config.dataDir, + config.fileName, + config.saveInterval, + config.cleanupInterval + ); + + case 'redis': + if (!config.redisUrl) { + throw new Error('Redis URL is required for Redis backend'); + } + return createRedisBackend(config.redisUrl, { + ...config.redisOptions, + dbName: config.dbName || 'cache', + }); + + case 'cloudflareKV': + if (!config.kvBindingName || !config.dbName) { + throw new Error( + 'Cloudflare KV binding name and db name are required for Cloudflare 
KV backend' + ); + } + return createCloudflareKVBackend( + config.env, + config.kvBindingName, + config.dbName + ); + + default: + throw new Error(`Unsupported cache backend: ${config.backend}`); + } + } + + /** + * Get a value from the cache + */ + async get(key: string, namespace?: string): Promise { + const entry = await this.backend.get(key, namespace); + return entry ? entry.value : null; + } + + /** + * Get the full cache entry (with metadata) + */ + async getEntry( + key: string, + namespace?: string + ): Promise | null> { + return this.backend.get(key, namespace); + } + + /** + * Set a value in the cache + */ + async set( + key: string, + value: T, + options: CacheOptions = {} + ): Promise { + const finalOptions = { + ...options, + ttl: options.ttl ?? this.defaultTtl, + }; + + await this.backend.set(key, value, finalOptions); + } + + /** + * Set a value with TTL in seconds (convenience method) + */ + async setWithTtl( + key: string, + value: T, + ttlSeconds: number, + namespace?: string + ): Promise { + await this.set(key, value, { + ttl: ttlSeconds * 1000, + namespace, + }); + } + + /** + * Delete a value from the cache + */ + async delete(key: string, namespace?: string): Promise { + return this.backend.delete(key, namespace); + } + + /** + * Check if a key exists in the cache + */ + async has(key: string, namespace?: string): Promise { + return this.backend.has(key, namespace); + } + + /** + * Get all keys in a namespace + */ + async keys(namespace?: string): Promise { + return this.backend.keys(namespace); + } + + /** + * Clear all entries in a namespace (or all entries if no namespace) + */ + async clear(namespace?: string): Promise { + await this.backend.clear(namespace); + } + + /** + * Get cache statistics + */ + async getStats(namespace?: string): Promise { + return this.backend.getStats(namespace); + } + + /** + * Manually trigger cleanup of expired entries + */ + async cleanup(): Promise { + await this.backend.cleanup(); + } + + /** + * Wait for 
the backend to be ready + */ + async waitForReady(): Promise { + if ('waitForReady' in this.backend) { + await (this.backend as any).waitForReady(); + } + } + + /** + * Close the cache and cleanup resources + */ + async close(): Promise { + await this.backend.close(); + } + + /** + * Get or set pattern - get value, or compute and cache it if not found + */ + async getOrSet( + key: string, + factory: () => Promise | T, + options: CacheOptions = {} + ): Promise { + const existing = await this.get(key, options.namespace); + if (existing !== null) { + return existing; + } + + const value = await factory(); + await this.set(key, value, options); + return value; + } + + /** + * Increment a numeric value (atomic operation for supported backends) + */ + async increment( + key: string, + delta: number = 1, + options: CacheOptions = {} + ): Promise { + // For backends that don't support atomic increment, we simulate it + const current = (await this.get(key, options.namespace)) || 0; + const newValue = current + delta; + await this.set(key, newValue, options); + return newValue; + } + + /** + * Set multiple values at once + */ + async setMany( + entries: Array<{ key: string; value: T; options?: CacheOptions }>, + defaultOptions: CacheOptions = {} + ): Promise { + const promises = entries.map(({ key, value, options }) => + this.set(key, value, { ...defaultOptions, ...options }) + ); + await Promise.all(promises); + } + + /** + * Get multiple values at once + */ + async getMany( + keys: string[], + namespace?: string + ): Promise> { + const promises = keys.map(async (key) => ({ + key, + value: await this.get(key, namespace), + })); + return Promise.all(promises); + } +} + +// Default cache instances for different use cases +let defaultCache: CacheService | null = null; +let tokenCache: CacheService | null = null; +let sessionCache: CacheService | null = null; +let configCache: CacheService | null = null; +let oauthStore: CacheService | null = null; +let mcpServersCache: 
CacheService | null = null; +let apiRateLimiterCache: CacheService | null = null; +/** + * Get or create the default cache instance + */ +export function getDefaultCache(): CacheService { + if (!defaultCache) { + throw new Error('Default cache instance not found'); + } + return defaultCache; +} + +/** + * Get or create the token cache instance + */ +export function getTokenCache(): CacheService { + if (!tokenCache) { + throw new Error('Token cache instance not found'); + } + return tokenCache; +} + +/** + * Get or create the session cache instance + */ +export function getSessionCache(): CacheService { + if (!sessionCache) { + throw new Error('Session cache instance not found'); + } + return sessionCache; +} + +/** + * Get or create the token introspection cache instance + */ +export function getTokenIntrospectionCache(): CacheService { + // Use the same cache as tokens, just different namespace + return getTokenCache(); +} + +/** + * Get or create the config cache instance + */ +export function getConfigCache(): CacheService { + if (!configCache) { + throw new Error('Config cache instance not found'); + } + return configCache; +} + +/** + * Get or create the oauth store cache instance + */ +export function getOauthStore(): CacheService { + if (!oauthStore) { + throw new Error('Oauth store cache instance not found'); + } + return oauthStore; +} + +export function getMcpServersCache(): CacheService { + if (!mcpServersCache) { + throw new Error('Mcp servers cache instance not found'); + } + return mcpServersCache; +} + +/** + * Initialize cache with custom configuration + */ +export function initializeCache(config: CacheConfig): CacheService { + return new CacheService(config); +} + +export async function createCacheBackendsLocal(): Promise { + defaultCache = new CacheService({ + backend: 'memory', + defaultTtl: MS['5_MINUTES'], + cleanupInterval: MS['5_MINUTES'], + maxSize: 1000, + }); + + tokenCache = new CacheService({ + backend: 'memory', + defaultTtl: 
MS['5_MINUTES'], + saveInterval: 1000, // 1 second + cleanupInterval: MS['5_MINUTES'], + maxSize: 1000, + }); + + sessionCache = new CacheService({ + backend: 'file', + dataDir: 'data', + fileName: 'sessions-cache.json', + defaultTtl: MS['30_MINUTES'], + saveInterval: 1000, // 1 second + cleanupInterval: MS['5_MINUTES'], + }); + await sessionCache.waitForReady(); + + configCache = new CacheService({ + backend: 'memory', + defaultTtl: MS['30_DAYS'], + cleanupInterval: MS['5_MINUTES'], + maxSize: 100, + }); + + oauthStore = new CacheService({ + backend: 'file', + dataDir: 'data', + fileName: 'oauth-store.json', + saveInterval: 1000, // 1 second + cleanupInterval: MS['10_MINUTES'], + }); + await oauthStore.waitForReady(); + + mcpServersCache = new CacheService({ + backend: 'file', + dataDir: 'data', + fileName: 'mcp-servers-auth.json', + saveInterval: 1000, // 5 seconds + cleanupInterval: MS['5_MINUTES'], + }); + await mcpServersCache.waitForReady(); +} + +export function createCacheBackendsRedis(redisUrl: string): void { + logger.info('Creating cache backends with Redis', redisUrl); + let commonOptions: CacheConfig = { + backend: 'redis', + redisUrl: redisUrl, + defaultTtl: MS['5_MINUTES'], + cleanupInterval: MS['5_MINUTES'], + maxSize: 1000, + }; + + defaultCache = new CacheService({ + ...commonOptions, + dbName: 'default', + }); + + tokenCache = new CacheService({ + backend: 'memory', + defaultTtl: MS['1_MINUTE'], + cleanupInterval: MS['1_MINUTE'], + maxSize: 1000, + }); + + sessionCache = new CacheService({ + ...commonOptions, + dbName: 'session', + }); + + configCache = new CacheService({ + ...commonOptions, + dbName: 'config', + defaultTtl: undefined, + }); + + oauthStore = new CacheService({ + ...commonOptions, + dbName: 'oauth', + defaultTtl: undefined, + }); + + mcpServersCache = new CacheService({ + ...commonOptions, + dbName: 'mcp', + defaultTtl: undefined, + }); +} + +export function createCacheBackendsCF(env: any): void { + let commonOptions: CacheConfig 
= { + backend: 'cloudflareKV', + env: env, + kvBindingName: 'KV_STORE', + defaultTtl: MS['5_MINUTES'], + }; + defaultCache = new CacheService({ + ...commonOptions, + dbName: 'default', + }); + + tokenCache = new CacheService({ + ...commonOptions, + dbName: 'token', + defaultTtl: MS['10_MINUTES'], + }); + + sessionCache = new CacheService({ + ...commonOptions, + dbName: 'session', + }); + + configCache = new CacheService({ + ...commonOptions, + dbName: 'config', + defaultTtl: MS['30_DAYS'], + }); + + oauthStore = new CacheService({ + ...commonOptions, + dbName: 'oauth', + defaultTtl: undefined, + }); + + mcpServersCache = new CacheService({ + ...commonOptions, + dbName: 'mcp', + defaultTtl: undefined, + }); + + apiRateLimiterCache = new CacheService({ + ...commonOptions, + kvBindingName: 'API_RATE_LIMITER', + dbName: 'api-rate-limiter', + defaultTtl: undefined, + }); +} + +// Re-export types for convenience +export * from './types'; diff --git a/src/shared/services/cache/types.ts b/src/shared/services/cache/types.ts new file mode 100644 index 000000000..8875572bc --- /dev/null +++ b/src/shared/services/cache/types.ts @@ -0,0 +1,57 @@ +/** + * @file src/services/cache/types.ts + * Type definitions for the unified cache system + */ + +export interface CacheEntry { + value: T; + expiresAt?: number; + createdAt: number; + metadata?: Record; +} + +export interface CacheOptions { + ttl?: number; // Time to live in milliseconds + namespace?: string; // Cache namespace for organization + metadata?: Record; // Additional metadata +} + +export interface CacheStats { + hits: number; + misses: number; + sets: number; + deletes: number; + size: number; + expired: number; +} + +export interface CacheBackend { + get(key: string, namespace?: string): Promise | null>; + set(key: string, value: T, options?: CacheOptions): Promise; + delete(key: string, namespace?: string): Promise; + clear(namespace?: string): Promise; + has(key: string, namespace?: string): Promise; + 
keys(namespace?: string): Promise; + getStats(namespace?: string): Promise; + cleanup(): Promise; // Remove expired entries + close(): Promise; // Cleanup resources +} + +export interface CacheConfig { + backend: 'memory' | 'file' | 'redis' | 'cloudflareKV'; + defaultTtl?: number; // Default TTL in milliseconds + cleanupInterval?: number; // Cleanup interval in milliseconds + // File backend options + dataDir?: string; + fileName?: string; + saveInterval?: number; // Debounce save interval + // Redis backend options + redisUrl?: string; + redisOptions?: any; + // Memory backend options + maxSize?: number; // Maximum number of entries + // Cloudflare KV backend options + env?: any; + kvBindingName?: string; + dbName?: string; +} diff --git a/src/shared/services/cache/utils/rateLimiter.ts b/src/shared/services/cache/utils/rateLimiter.ts new file mode 100644 index 000000000..b6447702c --- /dev/null +++ b/src/shared/services/cache/utils/rateLimiter.ts @@ -0,0 +1,182 @@ +import { Redis, Cluster } from 'ioredis'; +import { RateLimiterKeyTypes } from '../../../../globals'; + +const RATE_LIMIT_LUA = ` +local tokensKey = KEYS[1] +local refillKey = KEYS[2] + +local capacity = tonumber(ARGV[1]) +local windowSize = tonumber(ARGV[2]) +local units = tonumber(ARGV[3]) +local now = tonumber(ARGV[4]) +local ttl = tonumber(ARGV[5]) +local consume = tonumber(ARGV[6]) -- 1 = consume, 0 = check only + +-- Reject invalid input +if units <= 0 or capacity <= 0 or windowSize <= 0 then + return {0, -1, -1} +end + +local lastRefill = tonumber(redis.call("GET", refillKey) or "0") +local tokens = tonumber(redis.call("GET", tokensKey) or "-1") + +local tokensModified = false +local refillModified = false + +-- Initialization +if tokens == -1 then + tokens = capacity + tokensModified = true +end + +if lastRefill == 0 then + lastRefill = now + refillModified = true +end + +-- Refill logic +local elapsed = now - lastRefill +if elapsed > 0 then + local rate = capacity / windowSize + local 
tokensToAdd = math.floor(elapsed * rate) + if tokensToAdd > 0 then + tokens = math.min(tokens + tokensToAdd, capacity) + lastRefill = now -- simpler and avoids drift + tokensModified = true + refillModified = true + end +end + +-- Consume logic +local allowed = 0 +local waitTime = 0 +local currentTokens = tokens + +if tokens >= units then + allowed = 1 + if consume == 1 then + tokens = tokens - units + tokensModified = true + end +else + local needed = units - currentTokens + local rate = capacity / windowSize + waitTime = (rate > 0) and math.floor(needed / rate) or -1 +end + +-- Save changes +if tokensModified then + redis.call("SET", tokensKey, tokens, "PX", ttl) +end + +if refillModified then + redis.call("SET", refillKey, lastRefill, "PX", ttl) +end + +return {allowed, waitTime, currentTokens} +`; + +class RedisRateLimiter { + private redis: Redis | Cluster; + private capacity: number; + private windowSize: number; + private tokensKey: string; + private lastRefillKey: string; + private keyTTL: number; + private scriptSha: string | null = null; // To store the SHA1 hash of the script + private keyType: RateLimiterKeyTypes; + private key: string; + + constructor( + redisClient: Redis | Cluster, + capacity: number, + windowSize: number, + key: string, + keyType: RateLimiterKeyTypes, + ttlFactor: number = 3 // multiplier for TTL + ) { + this.redis = redisClient; + const tag = `{rate:${key}}`; // ensures same hash slot + this.capacity = capacity; + this.windowSize = windowSize; + this.tokensKey = `${tag}:tokens`; + this.lastRefillKey = `${tag}:lastRefill`; + this.keyTTL = windowSize * ttlFactor; // dynamic TTL + this.keyType = keyType; + this.key = key; + } + + // Helper to load script if not already loaded and return SHA + private async loadOrGetScriptSha(): Promise { + if (this.scriptSha) { + return this.scriptSha; + } + // Load the script into Redis and get its SHA1 hash + const shaString: any = await this.redis.script('LOAD', RATE_LIMIT_LUA); + this.scriptSha = 
shaString; + return shaString; + } + + private async executeScript(keys: string[], args: string[]): Promise { + // Get SHA (loads script if not already loaded on current client) + const sha = await this.loadOrGetScriptSha(); + + try { + return await this.redis.evalsha(sha, keys.length, ...keys, ...args); + } catch (error: any) { + if (error.message.includes('NOSCRIPT')) { + // Script not loaded on target node - load it and retry with same SHA + await this.redis.script('LOAD', RATE_LIMIT_LUA); + return await this.redis.evalsha(sha, keys.length, ...keys, ...args); + } + throw error; + } + } + + async checkRateLimit( + units: number, + consumeTokens: boolean = true // Default to true to consume tokens + ): Promise<{ + keyType: RateLimiterKeyTypes; + key: string; + allowed: boolean; + waitTime: number; + currentTokens: number; + }> { + const now = Date.now(); + // Get the SHA, loading the script into Redis if this is the first time + const resp: any = await this.executeScript( + [this.tokensKey, this.lastRefillKey], + [ + this.capacity.toString(), + this.windowSize.toString(), + units.toString(), + now.toString(), + this.keyTTL.toString(), + consumeTokens ? 
'1' : '0', // Pass consume flag to Lua script + ] + ); + const [allowed, waitTime, currentTokens] = resp; + return { + keyType: this.keyType, + key: this.key, + allowed: allowed === 1, + waitTime: Number(waitTime), + currentTokens: Number(currentTokens), // Return current tokens + }; + } + + async getToken(): Promise { + return this.redis.get(this.tokensKey); + } + + async decrementToken( + units: number + ): Promise<{ allowed: boolean; waitTime: number }> { + // Call checkRateLimit ensuring tokens are consumed + const { allowed, waitTime } = await this.checkRateLimit(units, true); + return { allowed, waitTime }; + } +} + +export default RedisRateLimiter; diff --git a/src/shared/utils/logger.ts b/src/shared/utils/logger.ts new file mode 100644 index 000000000..3ad80ee63 --- /dev/null +++ b/src/shared/utils/logger.ts @@ -0,0 +1,128 @@ +/** + * @file src/utils/logger.ts + * Configurable logger utility for MCP Gateway + */ + +export enum LogLevel { + ERROR = 0, + CRITICAL = 1, // New level for critical information + WARN = 2, + INFO = 3, + DEBUG = 4, +} + +export interface LoggerConfig { + level: LogLevel; + prefix?: string; + timestamp?: boolean; + colors?: boolean; +} + +class Logger { + private config: LoggerConfig; + private colors = { + error: '\x1b[31m', // red + critical: '\x1b[35m', // magenta + warn: '\x1b[33m', // yellow + info: '\x1b[36m', // cyan + debug: '\x1b[37m', // white + reset: '\x1b[0m', + }; + + constructor(config: LoggerConfig) { + this.config = { + timestamp: true, + colors: true, + ...config, + }; + } + + private formatMessage(level: string, message: string): string { + const parts: string[] = []; + + if (this.config.timestamp) { + parts.push(`[${new Date().toISOString()}]`); + } + + if (this.config.prefix) { + parts.push(`[${this.config.prefix}]`); + } + + parts.push(`[${level.toUpperCase()}]`); + parts.push(message); + + return parts.join(' '); + } + + private log(level: LogLevel, levelName: string, message: string, data?: any) { + if (level 
> this.config.level) return; + + const formattedMessage = this.formatMessage(levelName, message); + const color = this.config.colors + ? this.colors[levelName as keyof typeof this.colors] + : ''; + const reset = this.config.colors ? this.colors.reset : ''; + + if (data !== undefined) { + console.log(`${color}${formattedMessage}${reset}`, data); + } else { + console.log(`${color}${formattedMessage}${reset}`); + } + } + + error(message: string, error?: Error | any) { + if (error instanceof Error) { + this.log(LogLevel.ERROR, 'error', `${message}: ${error.message}`); + if (this.config.level >= LogLevel.DEBUG) { + console.error(error.stack); + } + } else if (error) { + this.log(LogLevel.ERROR, 'error', message, error); + } else { + this.log(LogLevel.ERROR, 'error', message); + } + } + + critical(message: string, data?: any) { + this.log(LogLevel.CRITICAL, 'critical', message, data); + } + + warn(message: string, data?: any) { + this.log(LogLevel.WARN, 'warn', message, data); + } + + info(message: string, data?: any) { + this.log(LogLevel.INFO, 'info', message, data); + } + + debug(message: string, data?: any) { + this.log(LogLevel.DEBUG, 'debug', message, data); + } + + createChild(prefix: string): Logger { + return new Logger({ + ...this.config, + prefix: this.config.prefix ? `${this.config.prefix}:${prefix}` : prefix, + }); + } +} + +// Create default logger instance +const defaultConfig: LoggerConfig = { + level: process.env.LOG_LEVEL + ? LogLevel[process.env.LOG_LEVEL.toUpperCase() as keyof typeof LogLevel] || + LogLevel.ERROR + : process.env.NODE_ENV === 'production' + ? 
LogLevel.ERROR + : LogLevel.INFO, + timestamp: process.env.LOG_TIMESTAMP !== 'false', + colors: + process.env.LOG_COLORS !== 'false' && process.env.NODE_ENV !== 'production', +}; + +export const logger = new Logger(defaultConfig); + +// Helper to create a logger for a specific component +export function createLogger(prefix: string): Logger { + return logger.createChild(prefix); +} diff --git a/src/utils/misc.ts b/src/utils/misc.ts index 58ae4512b..9c14823c6 100644 --- a/src/utils/misc.ts +++ b/src/utils/misc.ts @@ -1,3 +1,6 @@ +import { Context } from 'hono'; +import { getRuntimeKey } from 'hono/adapter'; + export function toSnakeCase(str: string) { return str .replace(/([a-z])([A-Z])/g, '$1_$2') // Handle camelCase and PascalCase @@ -6,3 +9,13 @@ export function toSnakeCase(str: string) { .replace(/_+/g, '_') // Merge multiple underscores .toLowerCase(); } + +export const addBackgroundTask = ( + c: Context, + promise: Promise +) => { + if (getRuntimeKey() === 'workerd') { + c.executionCtx.waitUntil(promise); + } + // in other runtimes, the promise resolves in the background +}; diff --git a/wrangler.toml b/wrangler.toml index 378496704..1328e822b 100644 --- a/wrangler.toml +++ b/wrangler.toml @@ -3,9 +3,21 @@ compatibility_date = "2024-12-05" main = "src/index.ts" compatibility_flags = [ "nodejs_compat" ] +[durable_objects] +bindings = [ + { name = "API_RATE_LIMITER", class_name = "APIRateLimiter" }, + { name = "ATOMIC_COUNTER", class_name = "AtomicCounter" }, + { name = "CIRCUIT_BREAKER", class_name = "CircuitBreaker" }, +] + +[[kv_namespaces]] +binding = "KV_STORE" +id = "your-namespace-id" + [vars] ENVIRONMENT = 'dev' CUSTOM_HEADERS_TO_IGNORE = [] +ALBUS_BASEPATH = "https://albus.portkey.ai" # TODO: do we need this? 
# #Configuration for DEVELOPMENT environment From 6d123dbd0b4cbca4499cb386e6cad2dfcdb87ba2 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Tue, 30 Sep 2025 19:15:29 +0530 Subject: [PATCH 02/25] remvoe bindings --- wrangler.toml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/wrangler.toml b/wrangler.toml index 1328e822b..b01832af5 100644 --- a/wrangler.toml +++ b/wrangler.toml @@ -3,13 +3,6 @@ compatibility_date = "2024-12-05" main = "src/index.ts" compatibility_flags = [ "nodejs_compat" ] -[durable_objects] -bindings = [ - { name = "API_RATE_LIMITER", class_name = "APIRateLimiter" }, - { name = "ATOMIC_COUNTER", class_name = "AtomicCounter" }, - { name = "CIRCUIT_BREAKER", class_name = "CircuitBreaker" }, -] - [[kv_namespaces]] binding = "KV_STORE" id = "your-namespace-id" @@ -17,7 +10,6 @@ id = "your-namespace-id" [vars] ENVIRONMENT = 'dev' CUSTOM_HEADERS_TO_IGNORE = [] -ALBUS_BASEPATH = "https://albus.portkey.ai" # TODO: do we need this? # #Configuration for DEVELOPMENT environment From d55ecf8a1c155422e2cb2820e8073bca009c502f Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Tue, 30 Sep 2025 22:54:49 +0530 Subject: [PATCH 03/25] remove kv --- wrangler.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/wrangler.toml b/wrangler.toml index b01832af5..378496704 100644 --- a/wrangler.toml +++ b/wrangler.toml @@ -3,10 +3,6 @@ compatibility_date = "2024-12-05" main = "src/index.ts" compatibility_flags = [ "nodejs_compat" ] -[[kv_namespaces]] -binding = "KV_STORE" -id = "your-namespace-id" - [vars] ENVIRONMENT = 'dev' CUSTOM_HEADERS_TO_IGNORE = [] From 8407df16fd55a9c90d8a10288b099efbc03e2d4d Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Tue, 30 Sep 2025 23:12:12 +0530 Subject: [PATCH 04/25] rebase --- package-lock.json | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/package-lock.json b/package-lock.json index 383daf4f2..6baf1762b 100644 --- a/package-lock.json +++ b/package-lock.json 
@@ -1,19 +1,19 @@ { "name": "@portkey-ai/gateway", - "version": "1.11.2", + "version": "1.12.3", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@portkey-ai/gateway", - "version": "1.11.2", + "version": "1.12.3", "hasInstallScript": true, "license": "MIT", "dependencies": { "@aws-crypto/sha256-js": "^5.2.0", "@cfworker/json-schema": "^4.0.3", "@hono/node-server": "^1.3.3", - "@hono/node-ws": "^1.0.4", + "@hono/node-ws": "^1.2.0", "@portkey-ai/mustache": "^2.1.3", "@smithy/signature-v4": "^2.1.1", "@types/mustache": "^4.2.5", @@ -1372,9 +1372,10 @@ } }, "node_modules/@hono/node-ws": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/@hono/node-ws/-/node-ws-1.0.4.tgz", - "integrity": "sha512-0j1TMp67U5ym0CIlvPKcKtD0f2ZjaS/EnhOxFLs3bVfV+/4WInBE7hVe2x/7PLEsNIUK9+jVL8lPd28rzTAcZg==", + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@hono/node-ws/-/node-ws-1.2.0.tgz", + "integrity": "sha512-OBPQ8OSHBw29mj00wT/xGYtB6HY54j0fNSdVZ7gZM3TUeq0So11GXaWtFf1xWxQNfumKIsj0wRuLKWfVsO5GgQ==", + "license": "MIT", "dependencies": { "ws": "^8.17.0" }, @@ -1382,7 +1383,8 @@ "node": ">=18.14.1" }, "peerDependencies": { - "@hono/node-server": "^1.11.1" + "@hono/node-server": "^1.11.1", + "hono": "^4.6.0" } }, "node_modules/@humanwhocodes/module-importer": { From 1d99b8ab03fca7c9f3177d33a6ecb4ae058745be Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni <47327611+narengogi@users.noreply.github.com> Date: Wed, 1 Oct 2025 12:33:52 +0530 Subject: [PATCH 05/25] Apply suggestion from @matter-code-review[bot] Co-authored-by: matter-code-review[bot] <150888575+matter-code-review[bot]@users.noreply.github.com> --- src/shared/services/cache/backends/cloudflareKV.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/shared/services/cache/backends/cloudflareKV.ts b/src/shared/services/cache/backends/cloudflareKV.ts index 820d461e9..3df23d014 100644 --- a/src/shared/services/cache/backends/cloudflareKV.ts +++ 
b/src/shared/services/cache/backends/cloudflareKV.ts @@ -75,7 +75,7 @@ export class CloudflareKVCacheBackend implements CacheBackend { this.stats.hits++; return entry; } catch (error) { - logger.error('Redis get error:', error); + logger.error('Cloudflare KV get error:', error); this.stats.misses++; return null; } From 1c3d5c510f31785443993c46e37e9a8cf80b1492 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni <47327611+narengogi@users.noreply.github.com> Date: Wed, 1 Oct 2025 12:34:05 +0530 Subject: [PATCH 06/25] Apply suggestion from @matter-code-review[bot] Co-authored-by: matter-code-review[bot] <150888575+matter-code-review[bot]@users.noreply.github.com> --- src/shared/services/cache/backends/cloudflareKV.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/shared/services/cache/backends/cloudflareKV.ts b/src/shared/services/cache/backends/cloudflareKV.ts index 3df23d014..d4e9b8782 100644 --- a/src/shared/services/cache/backends/cloudflareKV.ts +++ b/src/shared/services/cache/backends/cloudflareKV.ts @@ -136,7 +136,7 @@ export class CloudflareKVCacheBackend implements CacheBackend { return fullKeys.map((key) => key.substring(prefix.length)); } catch (error) { - logger.error('Redis keys error:', error); + logger.error('Cloudflare KV keys error:', error); return []; } } From 2b379fbd0cf14456e4322df932284109d62da10d Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni <47327611+narengogi@users.noreply.github.com> Date: Wed, 1 Oct 2025 12:34:17 +0530 Subject: [PATCH 07/25] Apply suggestion from @matter-code-review[bot] Co-authored-by: matter-code-review[bot] <150888575+matter-code-review[bot]@users.noreply.github.com> --- src/shared/services/cache/backends/cloudflareKV.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/shared/services/cache/backends/cloudflareKV.ts b/src/shared/services/cache/backends/cloudflareKV.ts index d4e9b8782..1e9211438 100644 --- a/src/shared/services/cache/backends/cloudflareKV.ts +++ 
b/src/shared/services/cache/backends/cloudflareKV.ts @@ -151,7 +151,7 @@ export class CloudflareKVCacheBackend implements CacheBackend { size: keys.length, }; } catch (error) { - logger.error('Redis getStats error:', error); + logger.error('Cloudflare KV getStats error:', error); return { ...this.stats }; } } From 0bb75ddd27861a605d55cf2ad0a01434357ba148 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni <47327611+narengogi@users.noreply.github.com> Date: Wed, 1 Oct 2025 12:34:27 +0530 Subject: [PATCH 08/25] Apply suggestion from @matter-code-review[bot] Co-authored-by: matter-code-review[bot] <150888575+matter-code-review[bot]@users.noreply.github.com> --- src/shared/services/cache/backends/cloudflareKV.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shared/services/cache/backends/cloudflareKV.ts b/src/shared/services/cache/backends/cloudflareKV.ts index 1e9211438..60e80c6b3 100644 --- a/src/shared/services/cache/backends/cloudflareKV.ts +++ b/src/shared/services/cache/backends/cloudflareKV.ts @@ -162,9 +162,9 @@ export class CloudflareKVCacheBackend implements CacheBackend { } async cleanup(): Promise { - // Redis handles TTL automatically, so this is mostly a no-op + // Cloudflare KV handles TTL automatically, so this is mostly a no-op // We could scan for entries with manual expiration and clean them up - logger.debug('Redis cleanup - TTL handled automatically by Redis'); + logger.debug('Cloudflare KV cleanup - TTL handled automatically by Cloudflare KV'); } async close(): Promise { From 05e42a3934c9d18e47d69f4d24b8a948629edde8 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Wed, 1 Oct 2025 14:22:01 +0530 Subject: [PATCH 09/25] handle settings --- initializeSettings.ts | 7 +++---- src/shared/services/cache/backends/cloudflareKV.ts | 4 +++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/initializeSettings.ts b/initializeSettings.ts index 2ee0c6727..497f58888 100644 --- a/initializeSettings.ts +++ 
b/initializeSettings.ts @@ -44,13 +44,12 @@ const transformIntegrations = (integrations: any) => { }); }; -let settings: any = {}; +let settings: any = undefined; try { // @ts-expect-error const settingsFile = await import('./settings.json'); - if (!settingsFile) { - settings = undefined; - } else { + if (settingsFile) { + settings = {}; settings.organisationDetails = organisationDetails; if (settingsFile.integrations) { settings.integrations = transformIntegrations(settingsFile.integrations); diff --git a/src/shared/services/cache/backends/cloudflareKV.ts b/src/shared/services/cache/backends/cloudflareKV.ts index 60e80c6b3..fdcb10bc1 100644 --- a/src/shared/services/cache/backends/cloudflareKV.ts +++ b/src/shared/services/cache/backends/cloudflareKV.ts @@ -164,7 +164,9 @@ export class CloudflareKVCacheBackend implements CacheBackend { async cleanup(): Promise { // Cloudflare KV handles TTL automatically, so this is mostly a no-op // We could scan for entries with manual expiration and clean them up - logger.debug('Cloudflare KV cleanup - TTL handled automatically by Cloudflare KV'); + logger.debug( + 'Cloudflare KV cleanup - TTL handled automatically by Cloudflare KV' + ); } async close(): Promise { From 00c00e8285c812b11a926713ec9e45ada5bbe0b8 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Tue, 7 Oct 2025 19:35:07 +0530 Subject: [PATCH 10/25] handle redis rate limiter tokens rate limiting when tokens to decrement is greater than the available tokens --- src/shared/services/cache/utils/rateLimiter.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/shared/services/cache/utils/rateLimiter.ts b/src/shared/services/cache/utils/rateLimiter.ts index b6447702c..bba149cda 100644 --- a/src/shared/services/cache/utils/rateLimiter.ts +++ b/src/shared/services/cache/utils/rateLimiter.ts @@ -59,6 +59,10 @@ if tokens >= units then tokensModified = true end else + if tokens > 0 then + tokensModified = true + end + tokens = 0 local needed = units - 
currentTokens local rate = capacity / windowSize waitTime = (rate > 0) and math.floor(needed / rate) or -1 From a29315e48e4528c909d20fcfc7cc90adce2a4813 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Thu, 9 Oct 2025 12:17:30 +0530 Subject: [PATCH 11/25] remove cache backend changes --- src/index.ts | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/index.ts b/src/index.ts index 70850db8b..66f4495ed 100644 --- a/src/index.ts +++ b/src/index.ts @@ -48,18 +48,6 @@ import { messagesCountTokensHandler } from './handlers/messagesCountTokensHandle const app = new Hono(); const runtime = getRuntimeKey(); -// cache beackends will only get created during worker or app initialization depending on the runtime -if (getRuntimeKey() === 'workerd') { - app.use('*', (c: Context, next) => { - createCacheBackendsCF(env(c)); - return next(); - }); -} else if (getRuntimeKey() === 'node' && process.env.REDIS_CONNECTION_STRING) { - createCacheBackendsRedis(process.env.REDIS_CONNECTION_STRING); -} else { - createCacheBackendsLocal(); -} - /** * Middleware that conditionally applies compression middleware based on the runtime. 
* Compression is automatically handled for lagon and workerd runtimes From 9cc932d582e0dc127737b9ac67cc930de91bf261 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Thu, 9 Oct 2025 12:50:10 +0530 Subject: [PATCH 12/25] remove unused imports --- src/index.ts | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/index.ts b/src/index.ts index 66f4495ed..35653b922 100644 --- a/src/index.ts +++ b/src/index.ts @@ -37,17 +37,11 @@ import { imageEditsHandler } from './handlers/imageEditsHandler'; // Config import conf from '../conf.json'; import modelResponsesHandler from './handlers/modelResponsesHandler'; -import { - createCacheBackendsLocal, - createCacheBackendsRedis, - createCacheBackendsCF, -} from './shared/services/cache'; import { messagesCountTokensHandler } from './handlers/messagesCountTokensHandler'; // Create a new Hono server instance const app = new Hono(); const runtime = getRuntimeKey(); - /** * Middleware that conditionally applies compression middleware based on the runtime. 
* Compression is automatically handled for lagon and workerd runtimes From 2505510288c3989fd677081bfce4785a0c2fa60c Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Thu, 9 Oct 2025 12:51:52 +0530 Subject: [PATCH 13/25] update settings example --- settings.example.json | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/settings.example.json b/settings.example.json index 5d9790d42..e3b2534d5 100644 --- a/settings.example.json +++ b/settings.example.json @@ -11,13 +11,11 @@ "type": "requests", "unit": "rph", "value": 3 - } - ], - "usage_limits": [ + }, { "type": "tokens", - "credit_limit": 1000000, - "periodic_reset": "weekly" + "unit": "rph", + "value": 3000 } ], "models": [ From bc56171109ea252e0fbfac927224e8f8b1eb4dad Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Thu, 9 Oct 2025 12:54:06 +0530 Subject: [PATCH 14/25] update settings initializer --- initializeSettings.ts | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/initializeSettings.ts b/initializeSettings.ts index 497f58888..5e0da3e6e 100644 --- a/initializeSettings.ts +++ b/initializeSettings.ts @@ -31,10 +31,10 @@ const transformIntegrations = (integrations: any) => { slug: integration.slug, usage_limits: null, status: 'active', - integration_id: '1234567890', + integration_id: integration.slug, object: 'virtual-key', integration_details: { - id: '1234567890', + id: integration.slug, slug: integration.slug, usage_limits: integration.usage_limits, rate_limits: integration.rate_limits, @@ -46,7 +46,6 @@ const transformIntegrations = (integrations: any) => { let settings: any = undefined; try { - // @ts-expect-error const settingsFile = await import('./settings.json'); if (settingsFile) { settings = {}; From 99f6262585b3a2d42affc7161eb6fd21803eebcb Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Thu, 9 Oct 2025 14:56:07 +0530 Subject: [PATCH 15/25] dont hardcode id --- initializeSettings.ts | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/initializeSettings.ts b/initializeSettings.ts index 5e0da3e6e..a75b00be6 100644 --- a/initializeSettings.ts +++ b/initializeSettings.ts @@ -20,7 +20,7 @@ const organisationDetails = { const transformIntegrations = (integrations: any) => { return integrations.map((integration: any) => { return { - id: '1234567890', //need to do consistent hashing for caching + id: integration.slug, //need to do consistent hashing for caching ai_provider_name: integration.provider, model_config: { ...integration.credentials, From d331fd8bc734785f5a05080bcea26481a31dbb59 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Thu, 9 Oct 2025 14:58:41 +0530 Subject: [PATCH 16/25] remove unused variable --- src/index.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/index.ts b/src/index.ts index 35653b922..5ac987027 100644 --- a/src/index.ts +++ b/src/index.ts @@ -8,7 +8,7 @@ import { Context, Hono } from 'hono'; import { prettyJSON } from 'hono/pretty-json'; import { HTTPException } from 'hono/http-exception'; import { compress } from 'hono/compress'; -import { env, getRuntimeKey } from 'hono/adapter'; +import { getRuntimeKey } from 'hono/adapter'; // import { env } from 'hono/adapter' // Have to set this up for multi-environment deployment // Middlewares @@ -38,6 +38,7 @@ import { imageEditsHandler } from './handlers/imageEditsHandler'; import conf from '../conf.json'; import modelResponsesHandler from './handlers/modelResponsesHandler'; import { messagesCountTokensHandler } from './handlers/messagesCountTokensHandler'; +import { portkey } from './middlewares/portkey'; // Create a new Hono server instance const app = new Hono(); @@ -97,6 +98,7 @@ app.get('/v1/models', modelsHandler); // Use hooks middleware for all routes app.use('*', hooks); +app.use('*', portkey()); if (conf.cache === true) { app.use('*', memoryCache()); From f4cdbe37573dfe69e0cc0d004f76d1d41c2e5811 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Thu, 
9 Oct 2025 15:39:12 +0530 Subject: [PATCH 17/25] handle nulls --- initializeSettings.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/initializeSettings.ts b/initializeSettings.ts index a75b00be6..b6704ecd9 100644 --- a/initializeSettings.ts +++ b/initializeSettings.ts @@ -1,10 +1,10 @@ -const organisationDetails = { +export const defaultOrganisationDetails = { id: '00000000-0000-0000-0000-000000000000', name: 'Portkey self hosted', settings: { debug_log: 1, is_virtual_key_limit_enabled: 1, - allowed_guardrails: ['BASIC'], + allowed_guardrails: ['BASIC', 'PARTNER', 'PRO'], }, workspaceDetails: {}, defaults: { @@ -49,7 +49,7 @@ try { const settingsFile = await import('./settings.json'); if (settingsFile) { settings = {}; - settings.organisationDetails = organisationDetails; + settings.organisationDetails = defaultOrganisationDetails; if (settingsFile.integrations) { settings.integrations = transformIntegrations(settingsFile.integrations); } From d2ca7efe1bc7b128f2b1c2b26cf5cac8eabb5184 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Thu, 9 Oct 2025 15:40:23 +0530 Subject: [PATCH 18/25] remove import --- src/index.ts | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/index.ts b/src/index.ts index 5ac987027..15f3d43ca 100644 --- a/src/index.ts +++ b/src/index.ts @@ -38,7 +38,6 @@ import { imageEditsHandler } from './handlers/imageEditsHandler'; import conf from '../conf.json'; import modelResponsesHandler from './handlers/modelResponsesHandler'; import { messagesCountTokensHandler } from './handlers/messagesCountTokensHandler'; -import { portkey } from './middlewares/portkey'; // Create a new Hono server instance const app = new Hono(); @@ -98,7 +97,6 @@ app.get('/v1/models', modelsHandler); // Use hooks middleware for all routes app.use('*', hooks); -app.use('*', portkey()); if (conf.cache === true) { app.use('*', memoryCache()); From b372b2fa7ce8f9886c94001fcf375eb98170e6b4 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni 
Date: Thu, 9 Oct 2025 16:31:18 +0530 Subject: [PATCH 19/25] delete transfer encoding for node --- src/handlers/services/responseService.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/handlers/services/responseService.ts b/src/handlers/services/responseService.ts index 8957ce5b2..1fe2b4dfd 100644 --- a/src/handlers/services/responseService.ts +++ b/src/handlers/services/responseService.ts @@ -124,9 +124,9 @@ export class ResponseService { // Remove headers directly if (getRuntimeKey() == 'node') { response.headers.delete('content-encoding'); + response.headers.delete('transfer-encoding'); } response.headers.delete('content-length'); - // response.headers.delete('transfer-encoding'); return response; } From 8a0eb99126f196897895a711c3f7d60224cbc35c Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Thu, 9 Oct 2025 21:58:30 +0530 Subject: [PATCH 20/25] rate limits with resets --- initializeSettings.ts | 38 +- src/handlers/adminRoutesHandler.ts | 132 ++ src/index.ts | 15 + src/middlewares/portkey/circuitBreaker.ts | 418 ++++ src/middlewares/portkey/globals.ts | 267 +++ src/middlewares/portkey/handlers/albus.ts | 453 ++++ src/middlewares/portkey/handlers/cache.ts | 157 ++ .../portkey/handlers/configFile.ts | 41 + src/middlewares/portkey/handlers/feedback.ts | 25 + src/middlewares/portkey/handlers/helpers.ts | 1823 +++++++++++++++++ src/middlewares/portkey/handlers/hooks.ts | 137 ++ src/middlewares/portkey/handlers/kv.ts | 80 + src/middlewares/portkey/handlers/logger.ts | 49 + .../portkey/handlers/rateLimits.ts | 247 +++ src/middlewares/portkey/handlers/realtime.ts | 78 + src/middlewares/portkey/handlers/stream.ts | 823 ++++++++ src/middlewares/portkey/handlers/usage.ts | 143 ++ src/middlewares/portkey/index.ts | 696 +++++++ src/middlewares/portkey/mustache.d.ts | 4 + src/middlewares/portkey/types.ts | 509 +++++ src/middlewares/portkey/utils.ts | 717 +++++++ .../utils/anthropicMessagesStreamParser.ts | 159 ++ src/public/index.html | 1048 
+++++++++- src/shared/services/cache/backends/redis.ts | 38 +- src/shared/services/cache/index.ts | 4 + .../services/cache/utils/rateLimiter.ts | 16 +- 26 files changed, 8073 insertions(+), 44 deletions(-) create mode 100644 src/handlers/adminRoutesHandler.ts create mode 100644 src/middlewares/portkey/circuitBreaker.ts create mode 100644 src/middlewares/portkey/globals.ts create mode 100644 src/middlewares/portkey/handlers/albus.ts create mode 100644 src/middlewares/portkey/handlers/cache.ts create mode 100644 src/middlewares/portkey/handlers/configFile.ts create mode 100644 src/middlewares/portkey/handlers/feedback.ts create mode 100644 src/middlewares/portkey/handlers/helpers.ts create mode 100644 src/middlewares/portkey/handlers/hooks.ts create mode 100644 src/middlewares/portkey/handlers/kv.ts create mode 100644 src/middlewares/portkey/handlers/logger.ts create mode 100644 src/middlewares/portkey/handlers/rateLimits.ts create mode 100644 src/middlewares/portkey/handlers/realtime.ts create mode 100644 src/middlewares/portkey/handlers/stream.ts create mode 100644 src/middlewares/portkey/handlers/usage.ts create mode 100644 src/middlewares/portkey/index.ts create mode 100644 src/middlewares/portkey/mustache.d.ts create mode 100644 src/middlewares/portkey/types.ts create mode 100644 src/middlewares/portkey/utils.ts create mode 100644 src/middlewares/portkey/utils/anthropicMessagesStreamParser.ts diff --git a/initializeSettings.ts b/initializeSettings.ts index b6704ecd9..f453246eb 100644 --- a/initializeSettings.ts +++ b/initializeSettings.ts @@ -45,20 +45,30 @@ const transformIntegrations = (integrations: any) => { }; let settings: any = undefined; -try { - const settingsFile = await import('./settings.json'); - if (settingsFile) { - settings = {}; - settings.organisationDetails = defaultOrganisationDetails; - if (settingsFile.integrations) { - settings.integrations = transformIntegrations(settingsFile.integrations); +const loadSettings = async () => { + try { + 
const settingsFile = await import('./settings.json'); + if (settingsFile) { + settings = {}; + settings.organisationDetails = defaultOrganisationDetails; + if (settingsFile.integrations) { + settings.integrations = transformIntegrations( + settingsFile.integrations + ); + } } + } catch (error) { + console.log( + 'WARNING: Unable to import settings from the path, please make sure the file exists', + error + ); } -} catch (error) { - console.log( - 'WARNING: Unable to import settings from the path, please make sure the file exists', - error - ); -} +}; + +loadSettings(); + +const refreshSettings = async () => { + await loadSettings(); +}; -export { settings }; +export { settings, refreshSettings }; diff --git a/src/handlers/adminRoutesHandler.ts b/src/handlers/adminRoutesHandler.ts new file mode 100644 index 000000000..b4bd1719e --- /dev/null +++ b/src/handlers/adminRoutesHandler.ts @@ -0,0 +1,132 @@ +import { Context, Hono } from 'hono'; +import { refreshSettings } from '../../initializeSettings'; +import { getDefaultCache } from '../shared/services/cache'; +import { settings } from '../../initializeSettings'; +import { generateRateLimitKey } from '../middlewares/portkey/handlers/rateLimits'; +import { RateLimiterKeyTypes } from '../globals'; + +/** + * Helper function to authenticate admin requests + */ +async function authenticateAdmin(c: Context): Promise { + try { + const fs = await import('fs/promises'); + const path = await import('path'); + const settingsPath = path.join(process.cwd(), 'settings.json'); + const settingsData = await fs.readFile(settingsPath, 'utf-8'); + const settings = JSON.parse(settingsData); + + const authHeader = + c.req.header('Authorization') || c.req.header('authorization'); + const providedKey = + authHeader?.replace('Bearer ', '') || c.req.header('x-admin-api-key'); + + return providedKey === settings.adminApiKey; + } catch (error) { + console.error('Error authenticating admin:', error); + return false; + } +} + +/** + * GET route 
for /admin/settings + * Serves the settings configuration file (requires admin authentication) + */ +async function getSettingsHandler(c: Context): Promise { + const isAuthenticated = await authenticateAdmin(c); + if (!isAuthenticated) { + return c.json({ error: 'Unauthorized' }, 401); + } + + try { + const fs = await import('fs/promises'); + const path = await import('path'); + const settingsPath = path.join(process.cwd(), 'settings.json'); + const settingsData = await fs.readFile(settingsPath, 'utf-8'); + return c.json(JSON.parse(settingsData)); + } catch (error) { + console.error('Error reading settings.json:', error); + return c.json({ error: 'Settings file not found' }, 404); + } +} + +/** + * PUT route for /admin/settings + * Updates the settings configuration file (requires admin authentication) + */ +async function putSettingsHandler(c: Context): Promise { + const isAuthenticated = await authenticateAdmin(c); + if (!isAuthenticated) { + return c.json({ error: 'Unauthorized' }, 401); + } + + try { + const fs = await import('fs/promises'); + const path = await import('path'); + const settingsPath = path.join(process.cwd(), 'settings.json'); + const body = await c.req.json(); + await fs.writeFile(settingsPath, JSON.stringify(body, null, 2)); + await refreshSettings(); + return c.json({ success: true }); + } catch (error) { + console.error('Error writing settings.json:', error); + return c.json({ error: 'Failed to save settings' }, 500); + } +} + +async function resetIntegrationRateLimitHandler(c: Context): Promise { + const isAuthenticated = await authenticateAdmin(c); + if (!isAuthenticated) { + return c.json({ error: 'Unauthorized' }, 401); + } + + try { + const integrationId = c.req.param('integrationId'); + const organisationId = settings.organisationDetails.id; + const workspaceId = settings.organisationDetails?.workspaceDetails?.id; + const key = c.req.param('key'); + const rateLimit = settings.integrations + .find((integration) => integration.slug === 
integrationId)?.integration_details + ?.rate_limits.find( + (rateLimit) => rateLimit.type === key + ); + const workspaceKey = `${integrationId}-${workspaceId}`; + const rateLimitKey = generateRateLimitKey( + organisationId, + rateLimit.type, + RateLimiterKeyTypes.INTEGRATION_WORKSPACE, + workspaceKey, + rateLimit.unit + ); + const finalKey = `{rate:${rateLimitKey}}:${key}`; + const cache = getDefaultCache(); + await cache.delete(finalKey); + return c.json({ success: true }); + } catch (error) { + console.error('Error deleting cache:', error); + return c.json({ error: 'Failed to delete cache' }, 500); + } +} + +/** + * Admin routes handler + * Handles all /admin/* routes + */ +export function adminRoutesHandler() { + const adminApp = new Hono(); + + // Settings routes + adminApp.get('/settings', getSettingsHandler); + adminApp.put('/settings', putSettingsHandler); + adminApp.put( + '/integrations/ratelimit/:integrationId/:key/reset', + resetIntegrationRateLimitHandler + ); + + // Add more admin routes here as needed + // adminApp.get('/users', getUsersHandler); + // adminApp.post('/users', createUserHandler); + // etc. 
+ + return adminApp; +} diff --git a/src/index.ts b/src/index.ts index 15f3d43ca..1f3ffadfa 100644 --- a/src/index.ts +++ b/src/index.ts @@ -38,10 +38,18 @@ import { imageEditsHandler } from './handlers/imageEditsHandler'; import conf from '../conf.json'; import modelResponsesHandler from './handlers/modelResponsesHandler'; import { messagesCountTokensHandler } from './handlers/messagesCountTokensHandler'; +import { portkey } from './middlewares/portkey'; +import { adminRoutesHandler } from './handlers/adminRoutesHandler'; +import { createCacheBackendsRedis } from './shared/services/cache'; // Create a new Hono server instance const app = new Hono(); const runtime = getRuntimeKey(); + +if (runtime === 'node' && process.env.REDIS_CONNECTION_STRING) { + createCacheBackendsRedis(process.env.REDIS_CONNECTION_STRING); +} + /** * Middleware that conditionally applies compression middleware based on the runtime. * Compression is automatically handled for lagon and workerd runtimes @@ -84,6 +92,12 @@ if (runtime === 'node') { */ app.get('/', (c) => c.text('AI Gateway says hey!')); +/** + * Admin routes + * All /admin/* routes are handled by the admin routes handler + */ +app.route('/admin', adminRoutesHandler()); + // Use prettyJSON middleware for all routes app.use('*', prettyJSON()); @@ -97,6 +111,7 @@ app.get('/v1/models', modelsHandler); // Use hooks middleware for all routes app.use('*', hooks); +app.use('*', portkey()); if (conf.cache === true) { app.use('*', memoryCache()); diff --git a/src/middlewares/portkey/circuitBreaker.ts b/src/middlewares/portkey/circuitBreaker.ts new file mode 100644 index 000000000..605efe484 --- /dev/null +++ b/src/middlewares/portkey/circuitBreaker.ts @@ -0,0 +1,418 @@ +import { Context } from 'hono'; +import { env } from 'hono/adapter'; +import { StrategyModes } from '../../types/requestBody'; + +export interface CircuitBreakerConfig { + failure_threshold: number; + failure_threshold_percentage?: number; + cooldown_interval: number; // 
in milliseconds + failure_status_codes?: number[]; + minimum_requests?: number; +} + +interface CircuitBreakerStatus { + path: string; + is_open: boolean; + failure_count: number; + success_count: number; + minimum_requests?: number; + first_failure_time?: number; + cb_config: CircuitBreakerConfig; +} + +export interface CircuitBreakerContext { + configId: string; + pathStatusMap: Record; +} + +interface CircuitBreakerDOData { + [targetPath: string]: { + failure_count: number; + first_failure_time?: number; + success_count: number; + minimum_requests?: number; + }; +} + +// Helper function to get Durable Object stub +function getCircuitBreakerDO(env: any, configId: string) { + const id = env.CIRCUIT_BREAKER.idFromName(configId); + return env.CIRCUIT_BREAKER.get(id); +} + +/** + * Extracts circuit breaker configurations from config and creates path mappings + */ +export const extractCircuitBreakerConfigs = ( + config: Record, + configId: string, + parentPath: string = 'config', + parentBreakerConfig?: CircuitBreakerConfig +): CircuitBreakerContext => { + const pathStatusMap: Record = {}; + + function recursiveExtractBreakers( + currentConfig: Record, + currentPath: string, + inheritedCBConfig?: CircuitBreakerConfig + ) { + // Get breaker config for current level - inherit from parent if not specified + let currentCBConfig: CircuitBreakerConfig | undefined; + + if ( + (currentConfig.strategy?.cb_config?.failure_threshold || + currentConfig.strategy?.cb_config?.failure_threshold_percentage) && + currentConfig.strategy?.cb_config?.cooldown_interval + ) { + currentCBConfig = { + failure_threshold: currentConfig.strategy.cb_config.failure_threshold, + failure_threshold_percentage: + currentConfig.strategy.cb_config.failure_threshold_percentage, + cooldown_interval: Math.max( + currentConfig.strategy.cb_config.cooldown_interval, + 30000 + ), // minimum cooldown interval of 30 seconds + minimum_requests: currentConfig.strategy.cb_config.minimum_requests, + 
failure_status_codes: + currentConfig.strategy.cb_config.failure_status_codes, + }; + } else { + currentCBConfig = inheritedCBConfig; + } + + // If this is a target (has virtual_key) or a strategy with targets + if (currentConfig.virtual_key) { + if (currentCBConfig) { + pathStatusMap[currentPath] = { + path: currentPath, + is_open: false, + failure_count: 0, + success_count: 0, + cb_config: currentCBConfig, + }; + } + } + + // If this is a conditional strategy, ignore circuit breaker + if (currentConfig.strategy?.mode === StrategyModes.CONDITIONAL) { + currentCBConfig = undefined; + } + + // Process targets recursively + if (currentConfig.targets) { + currentConfig.targets.forEach((target: any, index: number) => { + const targetPath = `${currentPath}.targets[${index}]`; + recursiveExtractBreakers(target, targetPath, currentCBConfig); + }); + } + } + + recursiveExtractBreakers(config, parentPath, parentBreakerConfig); + + return { + configId, + pathStatusMap, + }; +}; + +/** + * Checks circuit breaker status from Durable Object storage + */ +export const checkCircuitBreakerStatus = async ( + env: any, + circuitBreakerContext: CircuitBreakerContext +): Promise => { + const { configId, pathStatusMap } = circuitBreakerContext; + if (Object.keys(pathStatusMap).length === 0) { + return null; + } + + const updatedPathStatusMap: Record = {}; + + try { + // Get circuit breaker data from Durable Object + const circuitBreakerDO = getCircuitBreakerDO(env, configId); + const response = await circuitBreakerDO.fetch( + new Request('https://dummy/getStatus') + ); + const circuitBreakerData: CircuitBreakerDOData = await response.json(); + + const now = Date.now(); + + let isUpdated = false; + + // Check each path's circuit breaker status + for (const [path, status] of Object.entries(pathStatusMap)) { + const pathData = circuitBreakerData[path]; + const failureCount = pathData?.failure_count || 0; + const successCount = pathData?.success_count || 0; + const firstFailureTime = 
pathData?.first_failure_time; + const minimumRequests = pathData?.minimum_requests || 5; + let currentFailurePercentage = 0; + if ( + (failureCount || successCount) && + failureCount + successCount >= minimumRequests + ) { + currentFailurePercentage = + (100 * failureCount) / (failureCount + successCount); + } + + const { + failure_threshold, + cooldown_interval, + failure_threshold_percentage, + } = status.cb_config; + + let isOpen = false; + + // Check if circuit should be open + if ( + (failure_threshold && failureCount >= failure_threshold) || + (failure_threshold_percentage && + currentFailurePercentage >= failure_threshold_percentage) + ) { + // If cooldown period hasn't passed, keep circuit open + if (firstFailureTime && now - firstFailureTime < cooldown_interval) { + isOpen = true; + } else { + // Cooldown period has passed, reset failure count + delete circuitBreakerData[path]; + isUpdated = true; + } + } + + updatedPathStatusMap[path] = { + ...status, + failure_count: failureCount, + first_failure_time: firstFailureTime, + success_count: successCount, + minimum_requests: minimumRequests, + is_open: isOpen, + }; + } + + // Update DO if needed (reset expired circuits) + if (isUpdated) { + await circuitBreakerDO.fetch( + new Request('https://dummy/update', { + method: 'POST', + body: JSON.stringify(circuitBreakerData), + headers: { 'Content-Type': 'application/json' }, + }) + ); + } + } catch (error) { + console.error( + `Error checking circuit breaker status for configId ${configId}:`, + error + ); + // Default to close circuits on error + for (const [path, status] of Object.entries(pathStatusMap)) { + updatedPathStatusMap[path] = { + ...status, + is_open: false, + }; + } + } + + return { + configId, + pathStatusMap: updatedPathStatusMap, + }; +}; + +/** + * Updates config with circuit breaker status + */ +export const getCircuitBreakerMappedConfig = ( + config: Record, + circuitBreakerContext: CircuitBreakerContext +): Record => { + const mappedConfig = { 
...config }; + const { pathStatusMap } = circuitBreakerContext; + + function recursiveUpdateStatus( + currentConfig: Record, + currentPath: string + ): boolean { + let allTargetsOpen = true; + let hasTargets = false; + + // Process targets recursively + if (currentConfig.targets) { + hasTargets = true; + currentConfig.targets.forEach((target: any, index: number) => { + const targetPath = `${currentPath}.targets[${index}]`; + const targetOpen = recursiveUpdateStatus(target, targetPath); + + // Update target's is_open status + if (pathStatusMap[targetPath]) { + target.is_open = pathStatusMap[targetPath].is_open || targetOpen; + target.cb_config = pathStatusMap[targetPath].cb_config; + } + + if (!target.is_open) { + allTargetsOpen = false; + } + }); + } + + // Update current level status + if (pathStatusMap[currentPath]) { + // If this level has its own circuit breaker status, use it + currentConfig.is_open = pathStatusMap[currentPath].is_open; + currentConfig.cb_config = pathStatusMap[currentPath].cb_config; + + // If all targets are open, mark strategy as open + if (hasTargets && allTargetsOpen) { + currentConfig.is_open = true; + } + + return currentConfig.is_open; + } + + // If no circuit breaker config at this level, return whether all targets are open + return hasTargets ? 
allTargetsOpen : false; + } + + recursiveUpdateStatus(mappedConfig, 'config'); + mappedConfig.id = circuitBreakerContext.configId; + return mappedConfig; +}; + +/** + * Records a failure for circuit breaker using Durable Object + */ +export const recordCircuitBreakerFailure = async ( + env: any, + configId: string, + cbConfig: CircuitBreakerConfig, + targetPath: string, + errorStatusCode: number +): Promise => { + const failureStatusCodes = getCircuitBreakerStatusCodes(cbConfig); + if (!isCircuitBreakerFailure(errorStatusCode, failureStatusCodes)) { + return; + } + const now = Date.now(); + + try { + const circuitBreakerDO = getCircuitBreakerDO(env, configId); + await circuitBreakerDO.fetch( + new Request('https://dummy/recordFailure', { + method: 'POST', + body: JSON.stringify({ path: targetPath, timestamp: now }), + headers: { 'Content-Type': 'application/json' }, + }) + ); + } catch (error) { + console.error( + `Error recording circuit breaker failure for ${targetPath}:`, + error + ); + } +}; + +/** + * Records a success for circuit breaker using Durable Object + */ +export const recordCircuitBreakerSuccess = async ( + env: any, + configId: string, + targetPath: string +): Promise => { + try { + const circuitBreakerDO = getCircuitBreakerDO(env, configId); + await circuitBreakerDO.fetch( + new Request('https://dummy/recordSuccess', { + method: 'POST', + body: JSON.stringify({ path: targetPath }), + headers: { 'Content-Type': 'application/json' }, + }) + ); + } catch (error) { + console.error( + `Error recording circuit breaker success for ${targetPath}:`, + error + ); + } +}; + +// Helper function to determine if status code should trigger circuit breaker +export function isCircuitBreakerFailure( + statusCode: number, + failureStatusCodes?: number[] +): boolean { + return ( + failureStatusCodes?.includes(statusCode) || + (!failureStatusCodes && statusCode >= 500) + ); +} + +export async function handleCircuitBreakerResponse( + response: Response | undefined, + 
configId: string, + cbConfig: CircuitBreakerConfig, + targetPath: string, + c: Context +): Promise { + if (!cbConfig) return; + + if (response?.ok) { + await recordCircuitBreakerSuccess(env(c), configId, targetPath); + } else if (response) { + await recordCircuitBreakerFailure( + env(c), + configId, + cbConfig, + targetPath, + response.status + ); + } +} + +export function generateCircuitBreakerConfigId( + configSlug: string, + workspaceId: string, + organisationId: string +): string { + return `${organisationId}:${workspaceId}:${configSlug}`; +} + +export async function destroyCircuitBreakerConfig( + env: any, + configSlug: string, + workspaceId: string, + organisationId: string +): Promise { + const configId = generateCircuitBreakerConfigId( + configSlug, + workspaceId, + organisationId + ); + + try { + const circuitBreakerDO = getCircuitBreakerDO(env, configId); + await circuitBreakerDO.fetch( + new Request('https://dummy/destroy', { + method: 'POST', + }) + ); + return true; + } catch (error) { + console.error( + `Error destroying circuit breaker config for ${configId}:`, + error + ); + return false; + } +} + +function getCircuitBreakerStatusCodes( + cbConfig: CircuitBreakerConfig +): number[] | undefined { + if (!cbConfig) { + return undefined; + } + return cbConfig.failure_status_codes; +} diff --git a/src/middlewares/portkey/globals.ts b/src/middlewares/portkey/globals.ts new file mode 100644 index 000000000..55f1995e8 --- /dev/null +++ b/src/middlewares/portkey/globals.ts @@ -0,0 +1,267 @@ +export const OPEN_AI: string = 'openai'; +export const COHERE: string = 'cohere'; +export const AZURE_OPEN_AI: string = 'azure-openai'; +export const ANTHROPIC: string = 'anthropic'; +export const ANYSCALE: string = 'anyscale'; +export const GOOGLE: string = 'google'; +export const TOGETHER_AI: string = 'together-ai'; +export const PERPLEXITY_AI: string = 'perplexity-ai'; +export const MISTRAL_AI: string = 'mistral-ai'; +export const OLLAMA: string = 'ollama'; +export 
const BEDROCK: string = 'bedrock'; +export const SAGEMAKER: string = 'sagemaker'; +export const VERTEX_AI: string = 'vertex-ai'; +export const WORKERS_AI: string = 'workers-ai'; +export const NOVITA_AI: string = 'novita-ai'; +export const AZURE_AI: string = 'azure-ai'; + +export const HEADER_KEYS = { + API_KEY: 'x-portkey-api-key', + MODE: 'x-portkey-mode', + CONFIG: 'x-portkey-config', + CACHE: 'x-portkey-cache', + CACHE_TTL: 'x-portkey-cache-ttl', + CACHE_REFRESH: 'x-portkey-cache-force-refresh', + RETRIES: 'x-portkey-retry-count', + TRACE_ID: 'x-portkey-trace-id', + METADATA: 'x-portkey-metadata', + PROMPT_VERSION_ID: 'x-portkey-prompt-version-id', + PROMPT_ID: 'x-portkey-prompt-id', + ORGANISATION_DETAILS: 'x-auth-organisation-details', + RUNTIME: 'x-portkey-runtime', + RUNTIME_VERSION: 'x-portkey-runtime-version', + PACKAGE_VERSION: 'x-portkey-package-version', + CONFIG_VERSION: 'x-portkey-config-version', + PROVIDER: 'x-portkey-provider', + VIRTUAL_KEY: 'x-portkey-virtual-key', + VIRTUAL_KEY_EXHAUSTED: 'x-portkey-virtual-key-exhausted', + VIRTUAL_KEY_EXPIRED: 'x-portkey-virtual-key-expired', + VIRTUAL_KEY_USAGE_LIMITS: 'x-portkey-virtual-key-usage-limits', + VIRTUAL_KEY_RATE_LIMITS: 'x-portkey-virtual-key-rate-limits', + VIRTUAL_KEY_ID: 'x-portkey-virtual-key-id', + VIRTUAL_KEY_DETAILS: 'x-portkey-virtual-key-details', + INTEGRATION_RATE_LIMITS: 'x-portkey-integration-rate-limits', + INTEGRATION_USAGE_LIMITS: 'x-portkey-integration-usage-limits', + INTEGRATION_EXHAUSTED: 'x-portkey-integration-exhausted', + INTEGRATION_ID: 'x-portkey-integration-id', + INTEGRATION_SLUG: 'x-portkey-integration-slug', + INTEGRATION_MODELS: 'x-portkey-integration-model-details', + INTEGRATION_DETAILS: 'x-portkey-integration-details', + CONFIG_SLUG: 'x-portkey-config-slug', + PROMPT_SLUG: 'x-portkey-prompt-slug', + AZURE_RESOURCE: 'x-portkey-azure-resource-name', + AZURE_DEPLOYMENT: 'x-portkey-azure-deployment-id', + AZURE_API_VERSION: 'x-portkey-azure-api-version', + 
AZURE_MODEL_NAME: 'x-portkey-azure-model-name', + REFRESH_PROMPT_CACHE: 'x-portkey-refresh-prompt-cache', + CACHE_CONTROL: 'cache-control', + AWS_AUTH_TYPE: 'x-portkey-aws-auth-type', + AWS_ROLE_ARN: 'x-portkey-aws-role-arn', + AWS_EXTERNAL_ID: 'x-portkey-aws-external-id', + BEDROCK_ACCESS_KEY_ID: 'x-portkey-aws-access-key-id', + BEDROCK_SECRET_ACCESS_KEY: 'x-portkey-aws-secret-access-key', + BEDROCK_REGION: 'x-portkey-aws-region', + BEDROCK_SESSION_TOKEN: 'x-portkey-aws-session-token', + SAGEMAKER_CUSTOM_ATTRIBUTES: 'x-portkey-amzn-sagemaker-custom-attributes', + SAGEMAKER_TARGET_MODEL: 'x-portkey-amzn-sagemaker-target-model', + SAGEMAKER_TARGET_VARIANT: 'x-portkey-amzn-sagemaker-target-variant', + SAGEMAKER_TARGET_CONTAINER_HOSTNAME: + 'x-portkey-amzn-sagemaker-target-container-hostname', + SAGEMAKER_INFERENCE_ID: 'x-portkey-amzn-sagemaker-inference-id', + SAGEMAKER_ENABLE_EXPLANATIONS: 'x-portkey-amzn-sagemaker-enable-explanations', + SAGEMAKER_INFERENCE_COMPONENT: 'x-portkey-amzn-sagemaker-inference-component', + SAGEMAKER_SESSION_ID: 'x-portkey-amzn-sagemaker-session-id', + SAGEMAKER_MODEL_NAME: 'x-portkey-amzn-sagemaker-model-name', + DEBUG_LOG_SETTING: 'x-portkey-debug', + VERTEX_AI_PROJECT_ID: 'x-portkey-vertex-project-id', + VERTEX_AI_REGION: 'x-portkey-vertex-region', + VERTEX_SERVICE_ACCOUNT_JSON: 'x-portkey-vertex-service-account-json', + WORKERS_AI_ACCOUNT_ID: 'x-portkey-workers-ai-account-id', + OPEN_AI_PROJECT: 'x-portkey-openai-project', + OPEN_AI_ORGANIZATION: 'x-portkey-openai-organization', + SPAN_ID: 'x-portkey-span-id', + SPAN_NAME: 'x-portkey-span-name', + PARENT_SPAN_ID: 'x-portkey-parent-span-id', + AZURE_DEPLOYMENT_NAME: 'x-portkey-azure-deployment-name', + AZURE_REGION: 'x-portkey-azure-region', + AZURE_ENDPOINT_NAME: 'x-portkey-azure-endpoint-name', + AZURE_DEPLOYMENT_TYPE: 'x-portkey-azure-deployment-type', + AZURE_AUTH_MODE: 'x-portkey-azure-auth-mode', + AZURE_MANAGED_CLIENT_ID: 'x-portkey-azure-managed-client-id', + 
AZURE_ENTRA_TENANT_ID: 'x-portkey-azure-entra-tenant-id', + AZURE_ENTRA_CLIENT_ID: 'x-portkey-azure-entra-client-id', + AZURE_ENTRA_CLIENT_SECRET: 'x-portkey-azure-entra-client-secret', + CUSTOM_HOST: 'x-portkey-custom-host', + FORWARD_HEADERS: 'x-portkey-forward-headers', + DEFAULT_INPUT_GUARDRAILS: 'x-portkey-default-input-guardrails', + DEFAULT_OUTPUT_GUARDRAILS: 'x-portkey-default-output-guardrails', + AZURE_FOUNDRY_URL: 'x-portkey-azure-foundry-url', + AUDIO_FILE_DURATION: 'x-portkey-audio-file-duration', +}; + +export const INTERNAL_HEADER_KEYS: Record = { + CLIENT_AUTH_SECRET: 'x-client-auth-secret', +}; + +export const RESPONSE_HEADER_KEYS: Record = { + RETRY_ATTEMPT_COUNT: 'x-portkey-retry-attempt-count', +}; + +export const CONTENT_TYPES = { + APPLICATION_JSON: 'application/json', + MULTIPART_FORM_DATA: 'multipart/form-data', + EVENT_STREAM: 'text/event-stream', + GENERIC_AUDIO_PATTERN: 'audio/', + GENERIC_IMAGE_PATTERN: 'image/', + APPLICATION_OCTET_STREAM: 'application/octet-stream', + PLAIN_TEXT: 'text/plain', + HTML: 'text/html', +}; + +export const KV_PREFIX = { + orgConfig: 'ORG_CONFIG_', + cacheOrgProviderKeyPrefix: 'PROVIDER_KEY_V2_', + promptPrefix: 'PROMPT_V2_', + prompPartialPrefix: 'PROMPT_PARTIAL_', + virtualKeyUsagePrefix: 'VK_USAGE_', + guardrailPrefix: 'GUARDRAIL_', + integrationPrefix: 'INTEGRATION_', +}; + +export const providerAuthHeaderMap = { + [OPEN_AI]: 'authorization', + [COHERE]: 'authorization', + [ANTHROPIC]: 'x-api-key', + [ANYSCALE]: 'authorization', + [AZURE_OPEN_AI]: 'api-key', + [TOGETHER_AI]: 'authorization', +}; + +export const providerAuthHeaderPrefixMap = { + [OPEN_AI]: 'Bearer ', + [COHERE]: 'Bearer ', + [ANTHROPIC]: '', + [ANYSCALE]: 'Bearer ', + [AZURE_OPEN_AI]: '', + [TOGETHER_AI]: '', + [OLLAMA]: '', +}; + +export const MODES = { + PROXY_V2: 'proxy-2', // Latest proxy route: /v1/* + RUBEUS_V2: 'rubeus-2', // Latest rubeus routes: /v1/chat/completions, /v1/completions, /v1/embeddings and /v1/prompts + PROXY: 
'proxy', // Deprecated proxy route /v1/proxy/* + RUBEUS: 'rubeus', // Deprecated rubeus routes: /v1/chatComplete, /v1/complete and /v1/embed + API: 'api', // Deprecated mode that is sent from /v1/prompts/:id/generate calls + REALTIME: 'realtime', +}; + +export const CACHE_STATUS = { + HIT: 'HIT', + SEMANTIC_HIT: 'SEMANTIC HIT', + MISS: 'MISS', + SEMANTIC_MISS: 'SEMANTIC MISS', + REFRESH: 'REFRESH', + DISABLED: 'DISABLED', +}; + +export enum AtomicCounterTypes { + COST = 'cost', + TOKENS = 'tokens', +} + +export enum AtomicOperations { + GET = 'GET', + RESET = 'RESET', + INCREMENT = 'INCREMENT', + DECREMENT = 'DECREMENT', +} + +export enum AtomicKeyTypes { + VIRTUAL_KEY = 'VIRTUAL_KEY', + API_KEY = 'API_KEY', + WORKSPACE = 'WORKSPACE', + INTEGRATION_WORKSPACE = 'INTEGRATION_WORKSPACE', +} + +export enum RateLimiterKeyTypes { + VIRTUAL_KEY = 'VIRTUAL_KEY', + API_KEY = 'API_KEY', + WORKSPACE = 'WORKSPACE', + INTEGRATION_WORKSPACE = 'INTEGRATION_WORKSPACE', +} + +export enum RateLimiterTypes { + REQUESTS = 'requests', + TOKENS = 'tokens', +} + +export enum CacheKeyTypes { + VIRTUAL_KEY = 'VKEY', + API_KEY = 'AKEY', + CONFIG = 'CONFIG', + PROMPT = 'PROMPT', + PROMPT_PARTIAL = 'PARTIAL', + ORGANISATION = 'ORG', + WORKSPACE = 'WS', + GUARDRAIL = 'GUARDRAIL', + INTEGRATIONS = 'INTEGRATIONS', +} + +export const RATE_LIMIT_UNIT_TO_WINDOW_MAPPING: Record = { + rpd: 86400000, + rph: 3600000, + rpm: 60000, + rps: 1000, +}; +export const cacheDisabledRoutesRegex = /\/v1\/audio\/.*/; + +export const GUARDRAIL_CATEGORIES: Record = { + BASIC: 'BASIC', + PARTNER: 'PARTNER', + PRO: 'PRO', +}; + +export const GUARDRAIL_CATEGORY_FLAG_MAP: Record = { + default: GUARDRAIL_CATEGORIES.BASIC, + portkey: GUARDRAIL_CATEGORIES.PRO, + pillar: GUARDRAIL_CATEGORIES.PARTNER, + patronus: GUARDRAIL_CATEGORIES.PARTNER, + aporia: GUARDRAIL_CATEGORIES.PARTNER, + sydelabs: GUARDRAIL_CATEGORIES.PARTNER, + mistral: GUARDRAIL_CATEGORIES.PARTNER, + pangea: GUARDRAIL_CATEGORIES.PARTNER, + promptfoo: 
GUARDRAIL_CATEGORIES.PARTNER, + bedrock: GUARDRAIL_CATEGORIES.PARTNER, + acuvity: GUARDRAIL_CATEGORIES.PARTNER, + azure: GUARDRAIL_CATEGORIES.PARTNER, + exa: GUARDRAIL_CATEGORIES.PARTNER, + lasso: GUARDRAIL_CATEGORIES.PARTNER, + promptsecurity: GUARDRAIL_CATEGORIES.PARTNER, + 'panw-prisma-airs': GUARDRAIL_CATEGORIES.PARTNER, +}; + +export enum EntityStatus { + ACTIVE = 'active', + EXHAUSTED = 'exhausted', + EXPIRED = 'expired', + ARCHIVED = 'archived', +} + +export enum HookTypePreset { + INPUT_GUARDRAILS = 'input_guardrails', + INPUT_MUTATORS = 'input_mutators', + BEFORE_REQUEST_HOOKS = 'before_request_hooks', + OUTPUT_MUTATORS = 'output_mutators', + OUTPUT_GUARDRAILS = 'output_guardrails', + AFTER_REQUEST_HOOKS = 'after_request_hooks', +} + +export const hookTypePresets = [ + HookTypePreset.INPUT_GUARDRAILS, + HookTypePreset.INPUT_MUTATORS, + HookTypePreset.BEFORE_REQUEST_HOOKS, + HookTypePreset.OUTPUT_MUTATORS, + HookTypePreset.OUTPUT_GUARDRAILS, + HookTypePreset.AFTER_REQUEST_HOOKS, +]; diff --git a/src/middlewares/portkey/handlers/albus.ts b/src/middlewares/portkey/handlers/albus.ts new file mode 100644 index 000000000..d74a2428f --- /dev/null +++ b/src/middlewares/portkey/handlers/albus.ts @@ -0,0 +1,453 @@ +import { AtomicCounterTypes, AtomicKeyTypes, CacheKeyTypes } from '../globals'; +import { WorkspaceDetails } from '../types'; +import { generateV2CacheKey } from './cache'; +import { fetchFromKVStore, putInKVStore } from './kv'; +import { fetchOrganisationProviderFromSlugFromFile } from './configFile'; + +const isLocalConfigEnabled = process.env.FETCH_SETTINGS_FROM_FILE === 'true'; + +/** + * Asynchronously fetch data from Albus. + * + * @param {string} url - The URL to fetch data from. + * @param {any} options - method and headers for the fetch request. + * @returns {Promise} - A Promise that resolves to the fetched data or null if an error occurs. 
+ */ +const fetchFromAlbus = async ( + url: string, + options: any +): Promise => { + if (isLocalConfigEnabled) { + if (url.includes('/v2/virtual-keys/')) { + return fetchOrganisationProviderFromSlugFromFile(url); + } + } + try { + const response = await fetch(url, options); + + if (response.ok) { + const responseFromAlbus: any = await response.json(); + return responseFromAlbus; + } else { + console.log( + 'not found in albus', + url, + await response.clone().text(), + response.status, + JSON.stringify(options) + ); + } + } catch (error) { + console.log('error in fetching API Key from Albus', error); + } + + return null; +}; + +/** + * Asynchronously fetches virtual key details using the virtual key slug. + * + * @param {any} env - Hono environment object + * @param {string} orgApiKey - The API key for the organization. + * @param {string} organisationId - The ID of the organization. + * @param {string} providerKeySlug - The virtual key slugs. + * @returns {Promise} - A Promise that resolves to the fetched data or null if an error occurs. 
+ */ +export const fetchOrganisationProviderFromSlug = async ( + env: any, + orgApiKey: string, + organisationId: string, + workspaceDetails: WorkspaceDetails, + providerKeySlug: string, + refetch?: boolean +): Promise => { + const cacheKey = generateV2CacheKey({ + organisationId, + workspaceId: workspaceDetails.id, + cacheKeyType: CacheKeyTypes.VIRTUAL_KEY, + key: providerKeySlug, + }); + if (!refetch) { + const responseFromKV = await fetchFromKVStore(env, cacheKey); + + if (responseFromKV) { + return responseFromKV; + } + } + + //check in albus, return and save in KV if found + const albusFetchUrl = `${env.ALBUS_BASEPATH}/v2/virtual-keys/${providerKeySlug}?organisation_id=${organisationId}&workspace_id=${workspaceDetails.id}`; + const albusFetchOptions = { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'x-client-id-gateway': env.CLIENT_ID, + }, + }; + + const responseFromAlbus = await fetchFromAlbus( + albusFetchUrl, + albusFetchOptions + ); + if (responseFromAlbus) { + await putInKVStore(env, cacheKey, responseFromAlbus); + } + + return responseFromAlbus; +}; + +export const updateOrganisationProviderKey = async ( + env: any, + orgApiKey: string, + organisationId: string, + workspaceDetails: WorkspaceDetails, + providerKeyId: string, + providerKeySlug: string, + updateObj: any +): Promise => { + //check in albus, return and save in KV if found + const albusFetchUrl = `${env.ALBUS_BASEPATH}/v2/virtual-keys/${providerKeyId}?organisation_id=${organisationId}&workspace_id=${workspaceDetails.id}`; + const albusFetchOptions = { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + 'x-client-id-gateway': env.CLIENT_ID, + }, + body: JSON.stringify(updateObj), + }; + + const responseFromAlbus = await fetchFromAlbus( + albusFetchUrl, + albusFetchOptions + ); + if (responseFromAlbus) { + await fetchOrganisationProviderFromSlug( + env, + orgApiKey, + organisationId, + workspaceDetails, + providerKeySlug, + true + ); + } +}; + +/** + * 
Fetches organization configuration details based on the given parameters. + * + * @param {Object} env - Hono configuration object. + * @param {string} orgApiKey - Organisation portkey API key. + * @param {string} organisationId - Organisation ID. + * @param {string} configSlug - Config slug identifier. + * @returns {Promise} A Promise that resolves to the organization configuration details, + * or null if the configuration is not found or if it fails. + */ +export const fetchOrganisationConfig = async ( + env: any, + orgApiKey: string, + organisationId: string, + workspaceDetails: WorkspaceDetails, + configSlug: string +) => { + const cacheKey = generateV2CacheKey({ + organisationId, + workspaceId: workspaceDetails.id, + cacheKeyType: CacheKeyTypes.CONFIG, + key: configSlug, + }); + // fetch the config based on configSlug + let configDetailsFromKV = await fetchFromKVStore(env, cacheKey); + if (configDetailsFromKV) { + return { + organisationConfig: JSON.parse(configDetailsFromKV.config), + configVersion: configDetailsFromKV.version_id, + }; + } + + //fetch from albus + const albusUrl = `${env.ALBUS_BASEPATH}/v2/configs/${configSlug}?organisation_id=${organisationId}&workspace_id=${workspaceDetails.id}`; + const albusFetchOptions = { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'x-client-id-gateway': env.CLIENT_ID, + }, + }; + + const responseFromAlbus: any = await fetchFromAlbus( + albusUrl, + albusFetchOptions + ); + if (responseFromAlbus) { + // found in albus add to KV cache(async) and return + await putInKVStore(env, cacheKey, responseFromAlbus); + return { + organisationConfig: JSON.parse(responseFromAlbus.config), + configVersion: responseFromAlbus.version_id, + }; + } + + return null; +}; + +export const fetchOrganisationPrompt = async ( + env: any, + organisationId: string, + workspaceDetails: WorkspaceDetails, + apiKey: string, + promptSlug: string, + isCacheRefreshEnabled: boolean +) => { + //check in KV cache, return if found + 
const cacheKey = generateV2CacheKey({ + organisationId, + workspaceId: workspaceDetails.id, + cacheKeyType: CacheKeyTypes.PROMPT, + key: promptSlug, + }); + + if (!isCacheRefreshEnabled) { + const responseFromCache = await fetchFromKVStore(env, cacheKey); + if (responseFromCache) { + return responseFromCache; + } + } + + //check in albus, return and save in KV if found + const albusFetchUrl = `${env.ALBUS_BASEPATH}/v2/prompts/${promptSlug}?organisation_id=${organisationId}&workspace_id=${workspaceDetails.id}`; + const albusFetchOptions = { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'x-client-id-gateway': env.CLIENT_ID, + }, + }; + + const responseFromAlbus = await fetchFromAlbus( + albusFetchUrl, + albusFetchOptions + ); + if (responseFromAlbus) { + await putInKVStore(env, cacheKey, responseFromAlbus); + return responseFromAlbus; + } + return null; +}; + +export const fetchOrganisationPromptPartial = async ( + env: any, + organisationId: string, + workspaceDetails: WorkspaceDetails, + apiKey: string, + promptPartialSlug: string, + isCacheRefreshEnabled: boolean +) => { + //check in KV cache, return if found + const cacheKey = generateV2CacheKey({ + organisationId, + workspaceId: workspaceDetails.id, + cacheKeyType: CacheKeyTypes.PROMPT_PARTIAL, + key: promptPartialSlug, + }); + if (!isCacheRefreshEnabled) { + const responseFromCache = await fetchFromKVStore(env, cacheKey); + if (responseFromCache) { + return responseFromCache; + } + } + + //check in albus, return and save in KV if found + const albusFetchUrl = `${env.ALBUS_BASEPATH}/v2/prompts/partials/${promptPartialSlug}?organisation_id=${organisationId}&workspace_id=${workspaceDetails.id}`; + const albusFetchOptions = { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'x-client-id-gateway': env.CLIENT_ID, + }, + }; + + const responseFromAlbus = await fetchFromAlbus( + albusFetchUrl, + albusFetchOptions + ); + if (responseFromAlbus) { + await putInKVStore(env, 
cacheKey, responseFromAlbus); + return responseFromAlbus; + } + return null; +}; + +export const fetchOrganisationGuardrail = async ( + env: any, + orgId: string, + workspaceId: string | null, + guardrailSlug: string, + isCacheRefreshEnabled: boolean +) => { + //check in KV cache, return if found + const cacheKey = generateV2CacheKey({ + organisationId: orgId, + cacheKeyType: CacheKeyTypes.GUARDRAIL, + key: guardrailSlug, + workspaceId, + }); + + if (!isCacheRefreshEnabled) { + const responseFromCache = await fetchFromKVStore(env, cacheKey); + if (responseFromCache) { + return responseFromCache; + } + } + + //check in albus, return and save in KV if found + const albusFetchUrl = `${env.ALBUS_BASEPATH}/v2/guardrails/${guardrailSlug}?organisation_id=${orgId}${workspaceId ? `&workspace_id=${workspaceId}` : ''}`; + const albusFetchOptions = { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'x-client-id-gateway': env.CLIENT_ID, + }, + }; + + const responseFromAlbus = await fetchFromAlbus( + albusFetchUrl, + albusFetchOptions + ); + if (responseFromAlbus) { + await putInKVStore(env, cacheKey, responseFromAlbus); + return responseFromAlbus; + } + return null; +}; + +export const fetchOrganisationIntegrations = async ( + env: any, + orgId: string, + apiKey: string, + isCacheRefreshEnabled: boolean +) => { + //check in KV cache, return if found + const cacheKey = generateV2CacheKey({ + organisationId: orgId, + cacheKeyType: CacheKeyTypes.INTEGRATIONS, + key: 'all', + }); + + if (!isCacheRefreshEnabled) { + const responseFromCache = await fetchFromKVStore(env, cacheKey); + if (responseFromCache) { + return responseFromCache; + } + } + + //check in albus, return and save in KV if found + const albusFetchUrl = `${env.ALBUS_BASEPATH}/v2/integrations/?organisation_id=${orgId}`; + const albusFetchOptions = { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'x-client-id-gateway': env.CLIENT_ID, + }, + }; + + const responseFromAlbus = 
await fetchFromAlbus( + albusFetchUrl, + albusFetchOptions + ); + + if (responseFromAlbus) { + await putInKVStore(env, cacheKey, responseFromAlbus.data); + return responseFromAlbus.data; + } + return null; +}; + +export async function resyncOrganisationData({ + env, + organisationId, + apiKeysToReset, + apiKeysToExhaust, + apiKeysToExpire, + apiKeysToAlertThreshold, + apiKeysToUpdateUsage, + virtualKeyIdsToReset, + virtualKeyIdsToAlertThreshold, + virtualKeyIdsToExhaust, + virtualKeyIdsToExpire, + virtualKeyIdsToUpdateUsage, + keysToExpire, + keysToExhaust, + keysToAlertThreshold, + keysToUpdateUsage, + integrationWorkspacesToUpdateUsage, +}: { + env: Record; + organisationId: string; + apiKeysToReset?: string[]; + apiKeysToExhaust?: string[]; + apiKeysToExpire?: string[]; + apiKeysToAlertThreshold?: string[]; + apiKeysToUpdateUsage?: { id: string; usage: number }[]; + virtualKeyIdsToReset?: string[]; + virtualKeyIdsToAlertThreshold?: string[]; + virtualKeyIdsToExhaust?: string[]; + virtualKeyIdsToExpire?: string[]; + virtualKeyIdsToUpdateUsage?: { id: string; usage: number }[]; + keysToExpire?: { + key: string; + type: AtomicKeyTypes; + }[]; + keysToExhaust?: { + key: string; + type: AtomicKeyTypes; + counterType?: AtomicCounterTypes; + metadata?: Record; + usageLimitId?: string; + }[]; + keysToAlertThreshold?: { + key: string; + type: AtomicKeyTypes; + counterType?: AtomicCounterTypes; + metadata?: Record; + usageLimitId?: string; + }[]; + keysToUpdateUsage?: { + key: string; + type: AtomicKeyTypes; + counterType?: AtomicCounterTypes; + metadata?: Record; + usageLimitId?: string; + }[]; + integrationWorkspacesToUpdateUsage?: { + integration_id: string; + workspace_id: string; + usage: number; + }[]; +}) { + const path = `${env.ALBUS_BASEPATH}/v1/organisation/${organisationId}/resync`; + const options: RequestInit = { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: env.PORTKEY_CLIENT_AUTH, + }, + body: JSON.stringify({ + 
apiKeysToReset, + apiKeysToExhaust, + apiKeysToExpire, + apiKeysToAlertThreshold, + apiKeysToUpdateUsage, + virtualKeyIdsToReset, + virtualKeyIdsToExhaust, + virtualKeyIdsToExpire, + virtualKeyIdsToAlertThreshold, + virtualKeyIdsToUpdateUsage, + keysToExpire, + keysToExhaust, + keysToAlertThreshold, + keysToUpdateUsage, + integrationWorkspacesToUpdateUsage, + }), + }; + return fetch(path, options); +} diff --git a/src/middlewares/portkey/handlers/cache.ts b/src/middlewares/portkey/handlers/cache.ts new file mode 100644 index 000000000..36f16371f --- /dev/null +++ b/src/middlewares/portkey/handlers/cache.ts @@ -0,0 +1,157 @@ +import { CACHE_STATUS, CacheKeyTypes, HEADER_KEYS } from '../globals'; + +const getCacheMaxAgeFromHeaders = (maxAgeHeader: string) => { + try { + const maxAgeFromHeader = maxAgeHeader.match(/max-age=(\d+)/)?.[1]; + if (maxAgeFromHeader) return parseInt(maxAgeFromHeader); + } catch (err) { + console.log('invalid maxAgeHeader', err); + } + return null; +}; + +// Cache Handling +export const getFromCache = async ( + env: any, + requestHeaders: any, + requestBody: any, + url: string, + organisationId: string, + cacheMode: string, + cacheMaxAge: number | null +) => { + //forward request to Kreacher service binding + let maxAge: number | null = null; + const maxAgeHeader = requestHeaders[HEADER_KEYS.CACHE_CONTROL]; + if (cacheMaxAge) { + maxAge = cacheMaxAge; + } else if (maxAgeHeader) { + maxAge = getCacheMaxAgeFromHeaders(maxAgeHeader); + } + + const fetchBasePath = env.KREACHER_WORKER_BASEPATH; + if (!fetchBasePath) { + return [null, null, null]; + } + const fetchUrl = `${fetchBasePath}/get`; + //delete content-length header + delete requestHeaders['content-length']; + const fetchOptions = { + method: 'POST', + headers: requestHeaders, + body: JSON.stringify({ + headers: requestHeaders, + request: requestBody, + url: url, + organisationId, + cacheMode: cacheMode ?? 
requestHeaders[HEADER_KEYS.CACHE], + maxAge: maxAge, + }), + }; + + try { + const response = await env.kreacher.fetch(fetchUrl, fetchOptions); + if (response.status === 200) { + const responseFromKreacher = await response.json(); + + // This is a hack to handle the existing cache records which stored the error message due to 200 status code. + // It is currently not possible to invalidate these specific cache records so putting a check to return null for these records. + // TODO: Remove this check after around 60 days. + if (responseFromKreacher.data && responseFromKreacher.data.length < 100) { + const parsedData = JSON.parse(responseFromKreacher.data); + if ( + parsedData['html-message'] && + typeof parsedData['html-message'] === 'string' + ) { + return [null, CACHE_STATUS.MISS, responseFromKreacher.cacheKey]; + } + } + + if ( + [CACHE_STATUS.HIT, CACHE_STATUS.SEMANTIC_HIT].includes( + responseFromKreacher.status + ) + ) { + return [ + responseFromKreacher.data, + responseFromKreacher.status, + responseFromKreacher.cacheKey, + ]; + } else { + return [ + null, + responseFromKreacher.status, + responseFromKreacher.cacheKey, + ]; + } + } else { + return [null, null, null]; + } + } catch (error) { + console.log('Error in fetching from kreacher worker', error); + return [null, null, null]; + } +}; + +export const putInCache = async ( + env: any, + requestHeaders: any, + requestBody: any, + responseBody: any, + url: string, + organisationId: string, + cacheMode: string | null, + cacheMaxAge: number | null +) => { + let maxAge: number | null = null; + const maxAgeHeader = requestHeaders[HEADER_KEYS.CACHE_CONTROL]; + if (cacheMaxAge) { + maxAge = cacheMaxAge; + } else if (maxAgeHeader) { + maxAge = getCacheMaxAgeFromHeaders(maxAgeHeader); + } + + //forward request & response to Kreacher service binding + const fetchBasePath = env.KREACHER_WORKER_BASEPATH; + const fetchUrl = `${fetchBasePath}/put`; + const fetchOptions = { + method: 'POST', + headers: requestHeaders, + 
body: JSON.stringify({ + headers: requestHeaders, + request: requestBody, + response: responseBody, + url: url, + organisationId, + cacheMode: cacheMode ?? requestHeaders[HEADER_KEYS.CACHE], + maxAge: maxAge, + }), + }; + + try { + await env.kreacher.fetch(fetchUrl, fetchOptions); + } catch (error) { + console.log('Error in putting in kreacher worker', error); + } +}; + +export function generateV2CacheKey({ + organisationId, + workspaceId, + cacheKeyType, + key, +}: { + organisationId: string; + workspaceId?: string | null; + cacheKeyType: CacheKeyTypes; + key: string; +}) { + let cacheKey = `${cacheKeyType}_${key}_`; + if (organisationId) { + cacheKey += `${CacheKeyTypes.ORGANISATION}_${organisationId}`; + } + if (workspaceId) { + cacheKey += `${CacheKeyTypes.WORKSPACE}_${workspaceId}`; + } + return cacheKey; +} diff --git a/src/middlewares/portkey/handlers/configFile.ts b/src/middlewares/portkey/handlers/configFile.ts new file mode 100644 index 000000000..c5de98ed6 --- /dev/null +++ b/src/middlewares/portkey/handlers/configFile.ts @@ -0,0 +1,41 @@ +import { + settings, + defaultOrganisationDetails, +} from '../../../../initializeSettings'; + +export const fetchOrganisationProviderFromSlugFromFile = async ( + url: string +) => { + const virtualKeySlug = url.split('/').pop()?.split('?')[0]; + return settings.integrations.find( + (integration: any) => integration.slug === virtualKeySlug + ); +}; + +// not supported +// export const fetchOrganisationConfig = async () => { +// return fetchFromJson('organisationConfig'); +// }; + +// not supported +// export const fetchOrganisationPrompt = async () => { +// return fetchFromJson('organisationPrompt'); +// }; + +// not supported +// export const fetchOrganisationPromptPartial = async () => { +// return fetchFromJson('organisationPromptPartial'); +// }; + +// not supported +// export const fetchOrganisationGuardrail = async () => { +// return fetchFromJson('organisationGuardrail'); +// }; + +export const 
fetchOrganisationDetailsFromFile = async () => { + return settings?.organisationDetails ?? defaultOrganisationDetails; +}; + +export const fetchOrganisationIntegrationsFromFile = async () => { + // return settings.integrations; +}; diff --git a/src/middlewares/portkey/handlers/feedback.ts b/src/middlewares/portkey/handlers/feedback.ts new file mode 100644 index 000000000..65218c92f --- /dev/null +++ b/src/middlewares/portkey/handlers/feedback.ts @@ -0,0 +1,25 @@ +export interface FeedbackLogObject { + value: number; + metadata: Record; + weight: number; + trace_id: string; +} + +export async function addFeedback( + env: any, + feedbackObj: FeedbackLogObject, + headers: Record +) { + try { + await env.feedbackWorker.fetch(`${env.GATEWAY_BASEPATH}/v1/feedback`, { + method: 'POST', + body: JSON.stringify(feedbackObj), + headers: { + ...headers, + 'Content-Type': 'application/json', + }, + }); + } catch (error) { + console.error(error); + } +} diff --git a/src/middlewares/portkey/handlers/helpers.ts b/src/middlewares/portkey/handlers/helpers.ts new file mode 100644 index 000000000..871fe767b --- /dev/null +++ b/src/middlewares/portkey/handlers/helpers.ts @@ -0,0 +1,1823 @@ +import Mustache from '@portkey-ai/mustache'; + +import { getMode } from '..'; +import { + AZURE_AI, + AZURE_OPEN_AI, + BEDROCK, + EntityStatus, + GUARDRAIL_CATEGORY_FLAG_MAP, + HEADER_KEYS, + HookTypePreset, + hookTypePresets, + MODES, + OPEN_AI, + providerAuthHeaderMap, + providerAuthHeaderPrefixMap, + RATE_LIMIT_UNIT_TO_WINDOW_MAPPING, + RateLimiterKeyTypes, + SAGEMAKER, + VERTEX_AI, + WORKERS_AI, +} from '../globals'; +import { + fetchOrganisationConfig, + fetchOrganisationGuardrail, + fetchOrganisationIntegrations, + fetchOrganisationPrompt, + fetchOrganisationPromptPartial, + fetchOrganisationProviderFromSlug, +} from './albus'; +import { + BaseGuardrail, + OrganisationDetails, + RateLimit, + VirtualKeyDetails, + WorkspaceDetails, +} from '../types'; +import { constructAzureFoundryURL } from 
'../utils'; +import { + checkCircuitBreakerStatus, + CircuitBreakerContext, + extractCircuitBreakerConfigs, + generateCircuitBreakerConfigId, + getCircuitBreakerMappedConfig, +} from '../circuitBreaker'; + +const mapCustomHeaders = ( + customHeaders: Record, + headers: Headers +) => { + const headerKeys: string[] = []; + Object.entries(customHeaders).forEach(([key, value]) => { + if (key && typeof key === 'string' && typeof value === 'string') { + const _key = key.toLowerCase(); + headerKeys.push(_key); + headers.set(_key, value); + } + }); + return headerKeys; +}; + +const getVirtualKeyFromModel = (model: string | undefined) => { + if (!model) { + return null; + } + if ( + model.startsWith('@') && + // Cloudflare workers ai exception + !model.startsWith('@cf/') && + !model.startsWith('@hf/') + ) { + return { + virtualKey: model.slice(1).split('/')[0], + model: model.slice(1).split('/').slice(1).join('/'), + }; + } + return null; +}; + +export const getUniqueVirtualKeysFromConfig = (config: Record) => { + let uniqueVirtualKeys: Set = new Set(); + + function recursiveCollectKeysFromTarget( + currentConfig: Record, + configTargetType: string + ) { + if (!currentConfig[configTargetType]) { + const virtualKeyDetailsFromModel = getVirtualKeyFromModel( + currentConfig.override_params?.model + ); + let virtualKeyFromModel: string | undefined; + let mappedModelName: string | undefined; + if (virtualKeyDetailsFromModel) { + virtualKeyFromModel = virtualKeyDetailsFromModel.virtualKey; + mappedModelName = virtualKeyDetailsFromModel.model; + } + if (currentConfig.provider?.startsWith('@')) { + const virtualKey = currentConfig.provider.slice(1); + uniqueVirtualKeys.add(virtualKey); + currentConfig.virtual_key = virtualKey; + } else if (currentConfig.virtual_key) { + uniqueVirtualKeys.add(currentConfig.virtual_key); + } else if (virtualKeyFromModel && mappedModelName) { + const virtualKey = virtualKeyFromModel; + uniqueVirtualKeys.add(virtualKey); + currentConfig.virtual_key = 
virtualKey; + currentConfig.override_params.model = mappedModelName; + } + } + if (currentConfig[configTargetType]) { + for (const target of currentConfig[configTargetType]) { + recursiveCollectKeysFromTarget(target, configTargetType); + } + } + } + + const configTargetType = config.options?.length ? 'options' : 'targets'; + + recursiveCollectKeysFromTarget(config, configTargetType); + + return [...uniqueVirtualKeys]; +}; + +export const getUniqueGuardrailsFromConfig = (config: Record) => { + let uniqueGuardrails: Set = new Set(); + let rawHooksPresent = false; + + function recursiveCollectGuardrailsFromTarget( + currentConfig: Record + ) { + ['before_request_hooks', 'after_request_hooks'].forEach((hookType) => { + if (currentConfig[hookType]) { + rawHooksPresent = true; + currentConfig[hookType].forEach((h: any) => { + if (!h.checks) uniqueGuardrails.add(h.id); + }); + } + }); + + ['output_guardrails', 'input_guardrails'].forEach((guardrailType) => { + if (currentConfig[guardrailType]) { + rawHooksPresent = true; + currentConfig[guardrailType].forEach((h: any) => { + if (typeof h === 'string') uniqueGuardrails.add(h); + if (typeof h === 'object' && h.id?.startsWith('pg-')) { + uniqueGuardrails.add(h.id); + } + }); + } + }); + + if (currentConfig.targets) { + for (const target of currentConfig.targets) { + recursiveCollectGuardrailsFromTarget(target); + } + } + } + + recursiveCollectGuardrailsFromTarget(config); + return { uniqueGuardrails: [...uniqueGuardrails], rawHooksPresent }; +}; + +export const getVirtualKeyMap = async ( + env: any, + virtualKeyArr: Array, + orgApiKey: string, + organisationId: string, + workspaceDetails: WorkspaceDetails +) => { + const promises = virtualKeyArr.map(async (virtualKey) => { + const apiKeyKVRecord = await fetchOrganisationProviderFromSlug( + env, + orgApiKey, + organisationId, + workspaceDetails, + virtualKey + ); + return { + virtualKey, + apiKeyKVRecord, + }; + }); + + const results = await Promise.all(promises); + + const 
virtualKeyMap: Record = {}; + const missingKeys: Array = []; + results.forEach(({ virtualKey, apiKeyKVRecord }) => { + if (apiKeyKVRecord) { + virtualKeyMap[virtualKey] = { ...apiKeyKVRecord }; + } else { + missingKeys.push(virtualKey); + } + }); + + return { + virtualKeyMap, + missingKeys, + }; +}; + +export const getGuardrailMap = async ( + env: any, + guardrailSlugArr: BaseGuardrail[] +): Promise<{ + guardrailMap: Record; + missingGuardrails: string[]; +}> => { + const fetchPromises = guardrailSlugArr.map(async (guardrail) => { + const guardrailKVRecord = await fetchOrganisationGuardrail( + env, + guardrail.organisationId, + guardrail.workspaceId || null, + guardrail.slug, + false + ); + return { guardrailSlug: guardrail.slug, guardrailKVRecord }; + }); + + const results = await Promise.all(fetchPromises); + + const guardrailMap: Record = {}; + const missingGuardrails: Array = []; + + results.forEach(({ guardrailSlug, guardrailKVRecord }) => { + if (guardrailKVRecord) { + guardrailMap[guardrailSlug] = { ...guardrailKVRecord }; + } else { + missingGuardrails.push(guardrailSlug); + } + }); + + return { + guardrailMap, + missingGuardrails, + }; +}; + +/** + * Retrieves a map of prompts from KV store or albus in an async manner. + * + * @param {string} env - CF environment. + * @param {string[]} promptSlugArr - An array of prompt slugs. + * @param {string} orgApiKey - The organisation's API key. + * @param {string} organisationId - The organisation's ID. + * @returns {Promise<{ promptMap: Object, missingPrompts: string[] }>} A Promise resolving to an object containing the prompt map and an array of missing prompt_ids which are not found. 
+ */ +export const getPromptMap = async ( + env: any, + promptSlugArr: Array, + orgApiKey: string, + organisationId: string, + workspaceDetails: WorkspaceDetails +): Promise<{ promptMap: Record; missingPrompts: string[] }> => { + const fetchPromises = promptSlugArr.map(async (promptSlug) => { + const promptKVRecord = await fetchOrganisationPrompt( + env, + organisationId, + workspaceDetails, + orgApiKey, + promptSlug, + false + ); + return { promptSlug, promptKVRecord }; + }); + + const results = await Promise.all(fetchPromises); + + const promptMap: Record = {}; + const missingPrompts: Array = []; + + results.forEach(({ promptSlug, promptKVRecord }) => { + if (promptKVRecord) { + promptMap[promptSlug] = { ...promptKVRecord }; + } else { + missingPrompts.push(promptSlug); + } + }); + + return { + promptMap, + missingPrompts, + }; +}; + +/** + * Retrieves a map of prompts from KV store or albus in an async manner. + * + * @param {string} env - CF environment. + * @param {string[]} promptPartialSlugArr - An array of prompt slugs. + * @param {string} orgApiKey - The organisation's API key. + * @param {string} organisationId - The organisation's ID. + * @returns {Promise<{ promptPartialMap: Record, missingPromptPartials: string[] }>} A Promise resolving to an object containing the prompt map and an array of missing prompt_ids which are not found. 
+ */ +export const getPromptPartialMap = async ( + env: any, + promptPartialSlugArr: Array, + orgApiKey: string, + organisationId: string, + workspaceDetails: WorkspaceDetails +): Promise<{ + promptPartialMap: Record; + missingPromptPartials: string[]; +}> => { + const fetchPromises = promptPartialSlugArr.map( + async (promptPartialSlugArr) => { + const promptKVRecord = await fetchOrganisationPromptPartial( + env, + organisationId, + workspaceDetails, + orgApiKey, + promptPartialSlugArr, + false + ); + return { promptPartialSlugArr, promptKVRecord }; + } + ); + + const results = await Promise.all(fetchPromises); + + const promptPartialMap: Record = {}; + const missingPromptPartials: Array = []; + + results.forEach(({ promptPartialSlugArr, promptKVRecord }) => { + if (promptKVRecord) { + promptPartialMap[promptPartialSlugArr] = { ...promptKVRecord }; + } else { + missingPromptPartials.push(promptPartialSlugArr); + } + }); + + return { + promptPartialMap, + missingPromptPartials, + }; +}; + +/** + * Returns a config object with prompts mapped based on the provided prompt map and request body. + * + * @param {Object} promptMap - A map of prompts where key is prompt_id and value is a prompt data object. + * @param {Object} config - The original config object. + * @param {Object} requestBody - The request body. + * @returns {Object} A new config object with prompts mapped based on prompt_id. + */ +export const getPromptMappedConfig = ( + promptMap: Record, + promptPartialMap: Record, + config: Record, + requestBody: Record, + promptIDFromURL: string +) => { + const mappedConfig = { ...config }; + + function recursiveAddPromptsToTarget(currentConfig: Record) { + if ( + !currentConfig.targets && + currentConfig.prompt_id && + promptMap[currentConfig.prompt_id] + ) { + const { + provider_key_slug, + id: promptUUID, + prompt_version_id, + } = promptMap[currentConfig.prompt_id]; + currentConfig.virtual_key = + currentConfig.virtual_key ?? 
provider_key_slug; + currentConfig.prompt_uuid = promptUUID; + currentConfig.prompt_version_id = prompt_version_id; + const { requestBody: overrideParams } = createRequestFromPromptData( + {}, + promptMap[currentConfig.prompt_id], + promptPartialMap, + requestBody, + currentConfig.prompt_id + ); + currentConfig.override_params = { + ...overrideParams, + ...currentConfig.override_params, + }; + } else if (!currentConfig.targets && currentConfig.virtual_key) { + // If only virtual key is present in prompt config, then add the promptID from url. + const { id: promptUUID, prompt_version_id } = promptMap[promptIDFromURL]; + currentConfig.prompt_id = promptIDFromURL; + currentConfig.virtual_key = currentConfig.virtual_key; + currentConfig.prompt_uuid = promptUUID; + currentConfig.prompt_version_id = prompt_version_id; + const { requestBody: overrideParams } = createRequestFromPromptData( + {}, + promptMap[currentConfig.prompt_id], + promptPartialMap, + requestBody, + currentConfig.prompt_id + ); + currentConfig.override_params = { + ...overrideParams, + ...currentConfig.override_params, + }; + } + + if (currentConfig.targets) { + for (const target of currentConfig.targets) { + recursiveAddPromptsToTarget(target); + } + } + } + + recursiveAddPromptsToTarget(mappedConfig); + + return mappedConfig; +}; + +const getIntegrationCredentials = ( + integrationsMap: Array, + apiKey: string, + checkId: string, + incomingCredentials: Record | undefined +) => { + if (incomingCredentials && typeof incomingCredentials === 'object') { + return incomingCredentials; + } + const checkProvider = checkId.split('.')[0]; + const integration = integrationsMap.find( + (integration: any) => integration.integration_slug === checkProvider + ); + if (integration) { + return { ...integration.credentials }; + } + if (checkProvider === 'portkey') { + return { apiKey }; + } + + return {}; +}; + +const isCheckAllowed = (organisationSettings: any, checkId: string) => { + const checkProvider = 
checkId.split('.')[0]; + if ( + organisationSettings?.allowed_guardrails?.includes( + GUARDRAIL_CATEGORY_FLAG_MAP[checkProvider] + ) === false + ) { + return false; + } + + return true; +}; + +export const getGuardrailMappedConfig = ( + guardrailMap: Record, + config: Record, + integrationsMap: any, + apiKey: string, + organisationSettings: any +) => { + const mappedConfig = { ...config }; + + const addGuardrailToHook = (hook: any, isShorthand: boolean) => { + // If hook is a string, convert it to an object with id + if (isShorthand) { + const hookId = typeof hook === 'string' ? hook : hook.id; + if (!guardrailMap[hookId]) { + // If raw checks are sent in shorthand, then add the credentials and is_enabled + if (hook && typeof hook === 'object') { + Object.keys(hook) + .filter((key) => key.split('.').length === 2) + .forEach((key) => { + const integrationCredentials = getIntegrationCredentials( + integrationsMap, + apiKey, + key, + hook[key]?.credentials + ); + let isEnabled = + typeof hook[key]?.is_enabled === 'boolean' + ? hook[key]?.is_enabled + : true; + if (isEnabled) { + isEnabled = isCheckAllowed(organisationSettings, key); + } + hook[key] = { + ...hook[key], + credentials: integrationCredentials, + is_enabled: isEnabled, + }; + }); + } + return hook; + } + const { checks, actions, version_id } = guardrailMap[hookId]; + const checksObject = checks.reduce( + (acc: Record, check: any) => { + let isEnabled = + typeof check?.is_enabled === 'boolean' ? 
check?.is_enabled : true; + if (isEnabled) { + isEnabled = isCheckAllowed(organisationSettings, check.id); + } + acc[check.id] = { + ...check.parameters, + is_enabled: isEnabled, + credentials: getIntegrationCredentials( + integrationsMap, + apiKey, + check.id, + check.parameters?.credentials + ), + // Exclude id since it's now the key + id: undefined, + }; + return acc; + }, + {} + ); + + return { + id: hookId, + ...checksObject, + guardrail_version_id: version_id, + on_fail: hook.on_fail || actions.on_fail, + on_success: hook.on_success || actions.on_success, + deny: typeof hook.deny === 'boolean' ? hook.deny : actions.deny, + async: typeof hook.async === 'boolean' ? hook.async : actions.async, + }; + } + + const { checks, actions = {}, version_id } = guardrailMap[hook.id] || hook; + + Object.assign(hook, { + checks, + type: hook.type || 'guardrail', + guardrail_version_id: version_id, + on_fail: hook.on_fail || actions.on_fail, + on_success: hook.on_success || actions.on_success, + deny: typeof hook.deny === 'boolean' ? hook.deny : actions.deny, + async: typeof hook.async === 'boolean' ? hook.async : actions.async, + }); + + hook.checks?.forEach((check: any) => { + if (!check.parameters) { + check.parameters = {}; + } + let isEnabled = + typeof check?.is_enabled === 'boolean' ? 
check?.is_enabled : true; + if (isEnabled) { + isEnabled = isCheckAllowed(organisationSettings, check.id); + } + check.parameters.credentials = getIntegrationCredentials( + integrationsMap, + apiKey, + check.id, + check.parameters.credentials + ); + check.is_enabled = isEnabled; + }); + return hook; + }; + + const processHooks = (hooks: any[] = [], isShorthand: boolean = false) => { + // Map the hooks array and filter out any undefined values + return hooks + .map((hook) => addGuardrailToHook(hook, isShorthand)) + .filter(Boolean); + }; + + const recursiveAddGuardrailToHooks = (currentConfig: Record) => { + hookTypePresets.forEach((hookType) => { + if (currentConfig[hookType]) { + currentConfig[hookType] = processHooks( + currentConfig[hookType], + [ + HookTypePreset.INPUT_GUARDRAILS, + HookTypePreset.INPUT_MUTATORS, + HookTypePreset.OUTPUT_MUTATORS, + HookTypePreset.OUTPUT_GUARDRAILS, + ].includes(hookType) + ); + } + }); + + currentConfig.targets?.forEach(recursiveAddGuardrailToHooks); + }; + + recursiveAddGuardrailToHooks(mappedConfig); + + return mappedConfig; +}; + +/** + * Extracts unique prompt IDs from a nested/simple config object. + * + * @param {Object} config - The config object. + * @returns {Array} An array of unique prompt IDs. + */ +export const getUniquePromptSlugsFromConfig = ( + config: Record +): Array => { + const uniquePromptSlugs: Set = new Set(); + + function recursiveCollectPromptSlugsFromTarget( + currentConfig: Record + ) { + if (!currentConfig.targets && currentConfig.prompt_id) { + uniquePromptSlugs.add(currentConfig.prompt_id); + } + if (currentConfig.targets) { + for (const target of currentConfig.targets) { + recursiveCollectPromptSlugsFromTarget(target); + } + } + } + + recursiveCollectPromptSlugsFromTarget(config); + + return [...uniquePromptSlugs]; +}; + +/** + * Gets the config object with LLM API keys mapped based on the virtual keys present in it. 
+ * + * @param {Record} virtualKeyMap - A mapping of virtual keys to its provider data. + * @param {Record} config - The original config object that needs to be mapped with API keys. + * @returns {Record} - The mapped config object with mapped virtual keys. + */ +export const getApiKeyMappedConfig = ( + virtualKeyMap: Record, + config: Record, + body: Record, + mappedHeaders: Headers +): Record => { + const mappedConfig = { ...config }; + + function recursiveAddKeysToTarget( + currentConfig: Record, + configTargetType: string + ) { + if ( + !currentConfig[configTargetType] && + currentConfig.virtual_key && + virtualKeyMap[currentConfig.virtual_key] + ) { + const { + key, + ai_provider_name, + model_config, + status, + usage_limits, + rate_limits, + integration_details: integrationDetails, + id, + } = virtualKeyMap[currentConfig.virtual_key]; + currentConfig.virtualKeyId = id; + if (integrationDetails) { + if (integrationDetails.status == EntityStatus.EXHAUSTED) { + currentConfig.isIntegrationExhausted = true; + } + currentConfig.integrationId = integrationDetails.id; + currentConfig.integrationSlug = integrationDetails.slug; + currentConfig.integrationModelDetails = { + allow_all_models: integrationDetails.allow_all_models, + models: integrationDetails.models || [], + }; + currentConfig.integrationUsageLimits = + integrationDetails.usage_limits || null; + currentConfig.integrationRateLimits = + integrationDetails.rate_limits || null; + } + if (status == EntityStatus.EXHAUSTED) { + currentConfig.isVirtualKeyExhausted = true; + } + if (status == EntityStatus.EXPIRED) { + currentConfig.isVirtualKeyExpired = true; + } + const vkUsageLimits = Array.isArray(usage_limits) + ? usage_limits + : usage_limits + ? 
[usage_limits] + : []; + currentConfig.virtualKeyUsageLimits = vkUsageLimits; + currentConfig.virtualKeyRateLimits = rate_limits || []; + currentConfig.provider = ai_provider_name; + if (currentConfig.provider !== VERTEX_AI) { + currentConfig.api_key = key; + } + + if (model_config?.customHost) { + currentConfig.custom_host = + currentConfig.custom_host ?? model_config.customHost; + } + + if (model_config?.customHeaders) { + const keys = mapCustomHeaders( + model_config.customHeaders, + mappedHeaders + ); + let forwardHeaders: string[] = currentConfig?.forward_headers || []; + forwardHeaders = [...forwardHeaders, ...keys]; + currentConfig.forward_headers = forwardHeaders; + } + + if (currentConfig.provider === AZURE_OPEN_AI && model_config) { + const { deployments } = model_config; + let deploymentConfig; + // Fetch override params for current config + const currentConfigOverrideParams = currentConfig.override_params; + const alias = currentConfigOverrideParams?.model ?? body['model']; + if (deployments) { + if (alias) { + deploymentConfig = deployments.find( + (_config: any) => _config.alias === alias + ); + } + if (!deploymentConfig) { + deploymentConfig = deployments.find( + (_config: any) => _config.is_default + ); + } + } + + if (deployments && !deploymentConfig) { + return { + status: 'failure', + message: 'No azure alias passed/default config found', + }; + } + currentConfig.resource_name = model_config.resourceName; + currentConfig.azure_auth_mode = model_config.azureAuthMode; + currentConfig.azure_managed_client_id = + model_config.azureManagedClientId; + currentConfig.azure_entra_client_id = model_config.azureEntraClientId; + currentConfig.azure_entra_client_secret = + model_config.azureEntraClientSecret; + currentConfig.azure_entra_tenant_id = model_config.azureEntraTenantId; + if (deploymentConfig) { + currentConfig.deployment_id = deploymentConfig.deploymentName; + currentConfig.api_version = deploymentConfig.apiVersion; + 
currentConfig.azure_model_name = deploymentConfig.aiModelName; + } else { + currentConfig.deployment_id = model_config.deploymentName; + currentConfig.api_version = model_config.apiVersion; + currentConfig.azure_model_name = model_config.aiModelName; + } + } + + if (currentConfig.provider === AZURE_AI && model_config) { + if (model_config.azureFoundryUrl) { + currentConfig.azure_foundry_url = model_config.azureFoundryUrl; + } else { + const foundryURL = constructAzureFoundryURL({ + azureDeploymentName: model_config?.azureDeploymentName, + azureDeploymentType: model_config?.azureDeploymentType, + azureEndpointName: model_config?.azureEndpointName, + azureRegion: model_config?.azureRegion, + }); + currentConfig.azure_foundry_url = foundryURL; + } + + currentConfig.azure_deployment_name = model_config.azureDeploymentName; + currentConfig.azure_api_version = model_config.azureApiVersion; + + currentConfig.azure_auth_mode = model_config.azureAuthMode; + currentConfig.azure_managed_client_id = + model_config.azureManagedClientId; + currentConfig.azure_entra_client_id = model_config.azureEntraClientId; + currentConfig.azure_entra_client_secret = + model_config.azureEntraClientSecret; + currentConfig.azure_entra_tenant_id = model_config.azureEntraTenantId; + } + + if ( + [BEDROCK, SAGEMAKER].includes(currentConfig.provider) && + model_config + ) { + const { + awsAuthType, + awsAccessKeyId, + awsSecretAccessKey, + awsRegion, + awsRoleArn, + awsExternalId, + } = model_config; + currentConfig.aws_auth_type = awsAuthType; + currentConfig.aws_secret_access_key = awsSecretAccessKey; + currentConfig.aws_region = awsRegion; + currentConfig.aws_access_key_id = awsAccessKeyId; + currentConfig.aws_role_arn = awsRoleArn; + currentConfig.aws_external_id = awsExternalId; + } + + if (currentConfig.provider === SAGEMAKER && model_config) { + const { + amznSagemakerCustomAttributes, + amznSagemakerTargetModel, + amznSagemakerTargetVariant, + amznSagemakerTargetContainerHostname, + 
amznSagemakerInferenceId, + amznSagemakerEnableExplanations, + amznSagemakerInferenceComponent, + amznSagemakerSessionId, + amznSagemakerModelName, + } = model_config; + currentConfig.amzn_sagemaker_custom_attributes = + currentConfig.amzn_sagemaker_custom_attributes || + amznSagemakerCustomAttributes; + currentConfig.amzn_sagemaker_target_model = + currentConfig.amzn_sagemaker_target_model || amznSagemakerTargetModel; + currentConfig.amzn_sagemaker_target_variant = + currentConfig.amzn_sagemaker_target_variant || + amznSagemakerTargetVariant; + currentConfig.amzn_sagemaker_target_container_hostname = + currentConfig.amzn_sagemaker_target_container_hostname || + amznSagemakerTargetContainerHostname; + currentConfig.amzn_sagemaker_inference_id = + currentConfig.amzn_sagemaker_inference_id || amznSagemakerInferenceId; + currentConfig.amzn_sagemaker_enable_explanations = + currentConfig.amzn_sagemaker_enable_explanations || + amznSagemakerEnableExplanations; + currentConfig.amzn_sagemaker_inference_component = + currentConfig.amzn_sagemaker_inference_component || + amznSagemakerInferenceComponent; + currentConfig.amzn_sagemaker_sessionId = + currentConfig.amzn_sagemaker_sessionId || amznSagemakerSessionId; + currentConfig.amzn_sagemaker_model_name = + currentConfig.amzn_sagemaker_model_name || amznSagemakerModelName; + } + + if (currentConfig.provider === VERTEX_AI && model_config) { + currentConfig.vertex_project_id = model_config.vertexProjectId; + currentConfig.vertex_region = + currentConfig.vertex_region || model_config.vertexRegion; + currentConfig.vertex_service_account_json = + model_config.vertexServiceAccountJson; + } + + if (currentConfig.provider === WORKERS_AI && model_config) { + currentConfig.workers_ai_account_id = model_config.workersAiAccountId; + } + + if (currentConfig.provider === OPEN_AI && model_config) { + if (model_config.openaiOrganization) { + currentConfig.openai_organization = model_config.openaiOrganization; + } + + if 
(model_config.openaiProject) { + currentConfig.openai_project = model_config.openaiProject; + } + } + } + + if (currentConfig[configTargetType]) { + for (const target of currentConfig[configTargetType]) { + recursiveAddKeysToTarget(target, configTargetType); + } + } + } + + const configTargetType = mappedConfig.options?.length ? 'options' : 'targets'; + + recursiveAddKeysToTarget(mappedConfig, configTargetType); + + return mappedConfig; +}; + +export const getConfigDetailsFromRequest = ( + requestHeaders: Record, + requestBody: Record, + path: string +) => { + const mode = getMode(requestHeaders, path); + const isHeaderConfigEnabledRequest = [ + MODES.PROXY, + MODES.PROXY_V2, + MODES.RUBEUS_V2, + ].includes(mode); + + const isBodyConfigEnabledRequest = [MODES.RUBEUS].includes(mode); + + const configHeader = requestHeaders.get(HEADER_KEYS.CONFIG); + + if (isBodyConfigEnabledRequest) { + if (typeof requestBody.config === 'string') { + return { + type: 'slug', + data: requestBody.config, + }; + } else if (typeof requestBody.config === 'object') { + return { + type: 'object', + data: requestBody.config, + }; + } + } + + if (isHeaderConfigEnabledRequest && configHeader) { + if (configHeader.startsWith('pc-')) { + return { + type: 'slug', + data: requestHeaders.get(HEADER_KEYS.CONFIG), + }; + } else { + try { + const parsedConfigJSON = JSON.parse(configHeader); + return { + type: 'object', + data: parsedConfigJSON, + }; + } catch (e) { + console.log('invalid config', e); + } + } + } + + return null; +}; + +// This function can be extended to do any further config mapping. +// Currently, this only replaces virtual key. 
+export const getMappedConfig = async ( + env: any, + config: Record, + apiKey: string, + organisationDetails: OrganisationDetails, + requestBody: Record, + path: string, + isVirtualKeyUsageEnabled: boolean, + mappedHeaders: Headers, + defaultGuardrails: BaseGuardrail[], + configSlug: string +) => { + const { + id: organisationId, + workspaceDetails, + settings: organisationSettings, + } = organisationDetails; + + // Create circuit breaker context + let circuitBreakerContext: CircuitBreakerContext | null = null; + let updatedCircuitBreakerContext: CircuitBreakerContext | null = null; + if (configSlug) { + const configId = generateCircuitBreakerConfigId( + configSlug, + workspaceDetails.id, + organisationId + ); + circuitBreakerContext = extractCircuitBreakerConfigs(config, configId); + // Check circuit breaker status + updatedCircuitBreakerContext = await checkCircuitBreakerStatus( + env, + circuitBreakerContext + ); + } + + let promptIDFromURL: string = ''; + + const isPromptCompletionsCall = path.startsWith('/v1/prompts/') + ? true + : false; + if (isPromptCompletionsCall) { + promptIDFromURL = path.split('/')[3]; + } + const promptSlugArr = getUniquePromptSlugsFromConfig(config); + + const isPromptCompletionsConfig = promptSlugArr.length > 0 ? 
true : false; + + if (promptIDFromURL && !promptSlugArr.includes(promptIDFromURL)) { + promptSlugArr.push(promptIDFromURL); + } + + // configs with prompt_id are only allowed in /v1/prompts route + if (!isPromptCompletionsCall && promptSlugArr.length) { + return { + status: 'error', + message: `You cannot pass config with prompt id in /v1/prompts route`, + }; + } + const { promptMap, missingPrompts } = await getPromptMap( + env, + promptSlugArr, + apiKey, + organisationId, + workspaceDetails + ); + if (missingPrompts.length > 0) { + return { + status: 'error', + message: `Following prompt_id are not valid: ${missingPrompts.join( + ', ' + )}`, + }; + } + + const { missingVariablePartials, uniquePartials } = + getUniquePromptPartialsFromPromptMap(promptMap, requestBody); + if (missingVariablePartials.length > 0) { + return { + status: 'error', + message: `Missing variable partials: ${missingVariablePartials.join( + ', ' + )}`, + }; + } + const { promptPartialMap, missingPromptPartials } = await getPromptPartialMap( + env, + uniquePartials, + apiKey, + organisationId, + workspaceDetails + ); + if (missingPromptPartials.length > 0) { + return { + status: 'error', + message: `Missing prompt partials: ${missingPromptPartials.join(', ')}`, + }; + } + + const promptMappedConfig = isPromptCompletionsConfig + ? 
getPromptMappedConfig( + promptMap, + promptPartialMap, + config, + requestBody, + promptIDFromURL + ) + : config; + + const virtualKeyArr = getUniqueVirtualKeysFromConfig(promptMappedConfig); + const { virtualKeyMap, missingKeys } = await getVirtualKeyMap( + env, + virtualKeyArr, + apiKey, + organisationId, + workspaceDetails + ); + if (missingKeys.length > 0) { + return { + status: 'error', + message: `Following keys are not valid: ${missingKeys.join(', ')}`, + }; + } + + const { uniqueGuardrails: guardrailKeyArr, rawHooksPresent } = + getUniqueGuardrailsFromConfig(promptMappedConfig); + const mappedGuardrailKeyArr: BaseGuardrail[] = guardrailKeyArr.map( + (guardrail) => ({ + slug: guardrail, + organisationId, + workspaceId: workspaceDetails.id, + }) + ); + + let integrations; + if ( + guardrailKeyArr.length > 0 || + rawHooksPresent || + defaultGuardrails.length > 0 + ) { + integrations = await fetchOrganisationIntegrations( + env, + organisationId, + apiKey, + false + ); + } + + defaultGuardrails.forEach((guardrail) => { + if (!guardrailKeyArr.includes(guardrail.slug)) { + mappedGuardrailKeyArr.push(guardrail); + } + }); + + const { guardrailMap, missingGuardrails } = await getGuardrailMap( + env, + mappedGuardrailKeyArr + ); + + if (missingGuardrails.length > 0) { + return { + status: 'error', + message: `Following guardrails are not valid: ${missingGuardrails.join( + ', ' + )}`, + }; + } + + let promptRequestURL = ''; + // For /v1/prompts with prompt_id configs, generate a request url based on modelType + if ( + isPromptCompletionsConfig && + Object.values(promptMap)[0]?.ai_model_type === 'chat' + ) { + promptRequestURL = `${env.GATEWAY_BASEPATH}/chat/completions`; + } else if ( + isPromptCompletionsConfig && + Object.values(promptMap)[0]?.ai_model_type === 'text' + ) { + promptRequestURL = `${env.GATEWAY_BASEPATH}/completions`; + } + + const guardrailMappedConfig = getGuardrailMappedConfig( + guardrailMap, + config, + integrations, + apiKey, + 
organisationSettings + ); + + if (Object.keys(virtualKeyMap).length === 0) { + return { + status: 'success', + data: guardrailMappedConfig, + promptRequestURL: promptRequestURL, + guardrailMap: guardrailMap, + integrations: integrations, + circuitBreakerContext: null, + }; + } + + const apiKeyMappedConfig = getApiKeyMappedConfig( + virtualKeyMap, + guardrailMappedConfig, + requestBody, + mappedHeaders + ); + // Apply circuit breaker status to config + const circuitBreakerMappedConfig = updatedCircuitBreakerContext + ? getCircuitBreakerMappedConfig( + apiKeyMappedConfig, + updatedCircuitBreakerContext + ) + : apiKeyMappedConfig; + + return { + status: 'success', + data: circuitBreakerMappedConfig, + promptRequestURL: promptRequestURL, + guardrailMap: guardrailMap, + integrations: integrations, + circuitBreakerContext: updatedCircuitBreakerContext, + }; +}; + +export const getUniquePromptPartialsFromPromptMap = ( + promptMap: Record, + requestBodyJSON: Record +): { + missingVariablePartials: string[]; + uniquePartials: string[]; +} => { + const missingVariablePartials: string[] = []; + const uniquePartials = new Set(); + + // Loop over each prompt in the map + Object.values(promptMap).forEach((promptData) => { + if (promptData.variable_components) { + // Safely parse variableComponents, which can contain various properties + const components = JSON.parse(promptData.variable_components); + + // Check and add 'partials' if they exist + if (Array.isArray(components.partials)) { + components.partials.forEach((partial: string) => + uniquePartials.add(partial) + ); + } + + // Check and add 'variablePartials' if they exist. 
Variable partial values are resolved from the request body + if (Array.isArray(components.variablePartials)) { + components.variablePartials.forEach((eachVariablePartial: string) => { + if ( + requestBodyJSON.variables && + requestBodyJSON.variables.hasOwnProperty(eachVariablePartial) + ) { + uniquePartials.add(requestBodyJSON.variables[eachVariablePartial]); + } else { + missingVariablePartials.push(eachVariablePartial); + } + }); + } + } + }); + + // Convert the Set to an array to return the unique partials + return { + missingVariablePartials: [...missingVariablePartials], + uniquePartials: [...uniquePartials], + }; +}; + +export const getMappedConfigFromRequest = async ( + env: any, + requestBody: Record, + requestHeaders: Headers, + orgApiKey: string, + organisationDetails: OrganisationDetails, + path: string, + isVirtualKeyUsageEnabled: boolean, + mappedHeaders: Headers, + defaultGuardrails: BaseGuardrail[] +) => { + const { id: organisationId, workspaceDetails } = organisationDetails; + const configDetails = getConfigDetailsFromRequest( + requestHeaders, + requestBody, + path + ); + if (!configDetails) { + let guardrailMap; + let integrations; + if (defaultGuardrails.length > 0) { + const guardrails = await getGuardrailMap(env, defaultGuardrails); + + guardrailMap = guardrails.guardrailMap; + if (guardrails.missingGuardrails.length > 0) { + return { + status: 'failure', + message: `Following default guardrails are not valid: ${guardrails.missingGuardrails.join( + ', ' + )}`, + }; + } + integrations = await fetchOrganisationIntegrations( + env, + organisationId, + orgApiKey, + false + ); + } + return { + status: 'success', + mappedConfig: null, + configVersion: null, + configSlug: null, + promptRequestURL: null, + guardrailMap: guardrailMap, + integrations: integrations, + }; + } + const store: Record = {}; + if (configDetails.type === 'slug') { + const orgConfigFromSlug = await fetchOrganisationConfig( + env, + orgApiKey, + organisationId, + 
workspaceDetails, + configDetails.data + ); + if (!orgConfigFromSlug) { + return { + status: 'failure', + message: 'Invalid config id passed', + }; + } + store.organisationConfig = { + ...orgConfigFromSlug.organisationConfig, + }; + store.configVersion = orgConfigFromSlug.configVersion; + store.configSlug = configDetails.data; + } else if (configDetails.type === 'object') { + store.organisationConfig = { + ...configDetails.data, + }; + } + + const mappedConfig = await getMappedConfig( + env, + store.organisationConfig, + orgApiKey, + organisationDetails, + requestBody, + path, + isVirtualKeyUsageEnabled, + mappedHeaders, + defaultGuardrails, + store.configSlug + ); + if (mappedConfig.status === 'error') { + return { + status: 'failure', + message: mappedConfig.message, + }; + } + + return { + status: 'success', + mappedConfig: mappedConfig.data, + configVersion: store.configVersion, + configSlug: store.configSlug, + promptRequestURL: mappedConfig.promptRequestURL, + guardrailMap: mappedConfig.guardrailMap, + integrations: mappedConfig.integrations, + circuitBreakerContext: mappedConfig.circuitBreakerContext, + }; +}; + +/** + * Handles the mapping of virtual key header to provider and authorization header + * + * @param {Object} env - Hono environment object. + * @param {string} orgApiKey - The organization's API key. + * @param {string} organisationId - The organization's ID. + * @param {Headers} headers - Original request headers object + * @param {string} mode - The mode for the request. Decided on the basis of route that is called + * @returns {Promise<{status: string, message?: string}>} - A promise resolving to an object + * with the status success/failure and an optional message in case of failure. 
+ */ +export const handleVirtualKeyHeader = async ( + env: any, + orgApiKey: string, + organisationId: string, + workspaceDetails: WorkspaceDetails, + headers: Headers, + mode: string, + requestBody: Record, + mappedURL: string +): Promise<{ status: string; message?: string }> => { + const virtualKeyDetailsFromModel = getVirtualKeyFromModel(requestBody?.model); + let virtualKeyFromModel: string | undefined; + let mappedModelName: string | undefined; + if (virtualKeyDetailsFromModel) { + virtualKeyFromModel = virtualKeyDetailsFromModel.virtualKey; + mappedModelName = virtualKeyDetailsFromModel.model; + } + const virtualKey = headers.get(HEADER_KEYS.PROVIDER)?.startsWith('@') + ? headers.get(HEADER_KEYS.PROVIDER)?.slice(1).split('/')[0] + : headers.get(HEADER_KEYS.VIRTUAL_KEY) + ? headers.get(HEADER_KEYS.VIRTUAL_KEY) + : virtualKeyFromModel + ? virtualKeyFromModel + : null; + + if (!virtualKey) { + return { + status: 'success', + }; + } + + if (requestBody?.model && mappedModelName && virtualKeyFromModel) { + requestBody.model = mappedModelName; + } + + headers.set(HEADER_KEYS.VIRTUAL_KEY, virtualKey); + const apiKeyKVRecord = await fetchOrganisationProviderFromSlug( + env, + orgApiKey, + organisationId, + workspaceDetails, + virtualKey + ); + if (!apiKeyKVRecord) { + return { + status: 'failure', + message: `Following keys are not valid: ${headers.get( + HEADER_KEYS.VIRTUAL_KEY + )}`, + }; + } + + const { + id: virtualKeyId, + ai_provider_name, + key, + model_config, + status, + usage_limits, + rate_limits: rateLimits, + slug, + expires_at, + integration_details: integrationDetails, + } = apiKeyKVRecord; + + if (integrationDetails) { + headers.set( + HEADER_KEYS.INTEGRATION_DETAILS, + JSON.stringify({ + id: integrationDetails.id, + slug: integrationDetails.slug, + status: integrationDetails.status, + allow_all_models: integrationDetails.allow_all_models, + models: integrationDetails.models || [], + usage_limits: integrationDetails.usage_limits || [], + rate_limits: 
integrationDetails.rate_limits || [],
+      })
+    );
+  }
+
+  const usageLimits = Array.isArray(usage_limits)
+    ? usage_limits
+    : usage_limits
+      ? [usage_limits]
+      : [];
+
+  const virtualKeyDetails: VirtualKeyDetails = {
+    status,
+    usage_limits: usageLimits || [],
+    rate_limits: rateLimits || [],
+    id: virtualKeyId,
+    workspace_id: workspaceDetails.id,
+    slug,
+    organisation_id: organisationId,
+    expires_at: expires_at,
+  };
+
+  headers.set(
+    HEADER_KEYS.VIRTUAL_KEY_DETAILS,
+    JSON.stringify(virtualKeyDetails)
+  );
+
+  const currentCustomHost = headers.get(HEADER_KEYS.CUSTOM_HOST);
+  if (model_config?.customHost && !currentCustomHost) {
+    headers.set(HEADER_KEYS.CUSTOM_HOST, model_config.customHost);
+  }
+
+  if (
+    providerAuthHeaderMap[ai_provider_name] &&
+    ![MODES.RUBEUS_V2, MODES.REALTIME].includes(mode)
+  ) {
+    headers.set(
+      providerAuthHeaderMap[ai_provider_name],
+      `${providerAuthHeaderPrefixMap[ai_provider_name]}${key}`
+    );
+  } else if (ai_provider_name !== VERTEX_AI) {
+    headers.set('authorization', `Bearer ${key}`);
+  }
+
+  // Order is important, headers can contain `authorization` which should be overridden if custom LLM.
+  if (model_config?.customHeaders) {
+    const allkeys = mapCustomHeaders(model_config.customHeaders, headers);
+    // Get current forward headers
+    const currentForwardHeaders = headers.get(HEADER_KEYS.FORWARD_HEADERS);
+
+    const finalHeaders =
+      allkeys.join(',') +
+      (currentForwardHeaders ? `,${currentForwardHeaders}` : '');
+    // Set forward headers
+    headers.set(HEADER_KEYS.FORWARD_HEADERS, finalHeaders);
+  }
+
+  // Azure OpenAI requires `Authorization` header for finetuning & files routes.
+ if ( + ai_provider_name === AZURE_OPEN_AI && + (mappedURL.includes('/fine_tuning') || mappedURL.includes('/files')) + ) { + headers.set('authorization', `Bearer ${key}`); + } + + if (ai_provider_name === AZURE_OPEN_AI && model_config) { + const { + resourceName, + deploymentName, + apiVersion, + aiModelName, + azureAuthMode, + azureManagedClientId, + azureEntraTenantId, + azureEntraClientId, + azureEntraClientSecret, + deployments, + } = model_config; + const azureDeploymentAlias = requestBody['model']; + // New Config + let deploymentConfig; + if (deployments) { + if (azureDeploymentAlias) { + deploymentConfig = deployments.find( + (_config: any) => _config.alias === azureDeploymentAlias + ); + } + // Fallback to default model in-case we didn't fine any alias in the vk config. + if (!deploymentConfig) { + deploymentConfig = deployments.find( + (_config: any) => _config.is_default + ); + } + } + + if (deployments && !deploymentConfig) { + return { + status: 'failure', + message: 'No azure alias passed/default config found', + }; + } + + headers.set(HEADER_KEYS.AZURE_RESOURCE, resourceName?.toString() ?? ''); + headers.set(HEADER_KEYS.AZURE_AUTH_MODE, azureAuthMode ?? ''); + headers.set( + HEADER_KEYS.AZURE_MANAGED_CLIENT_ID, + azureManagedClientId ?? '' + ); + headers.set(HEADER_KEYS.AZURE_ENTRA_CLIENT_ID, azureEntraClientId ?? ''); + headers.set( + HEADER_KEYS.AZURE_ENTRA_CLIENT_SECRET, + azureEntraClientSecret ?? '' + ); + headers.set(HEADER_KEYS.AZURE_ENTRA_TENANT_ID, azureEntraTenantId ?? ''); + if (deploymentConfig) { + headers.set( + HEADER_KEYS.AZURE_DEPLOYMENT, + deploymentConfig.deploymentName?.toString() ?? '' + ); + headers.set( + HEADER_KEYS.AZURE_API_VERSION, + deploymentConfig.apiVersion?.toString() ?? '' + ); + headers.set( + HEADER_KEYS.AZURE_MODEL_NAME, + deploymentConfig.aiModelName?.toString() ?? '' + ); + } else { + headers.set( + HEADER_KEYS.AZURE_DEPLOYMENT, + deploymentName?.toString() ?? 
'' + ); + headers.set(HEADER_KEYS.AZURE_API_VERSION, apiVersion?.toString() ?? ''); + headers.set(HEADER_KEYS.AZURE_MODEL_NAME, aiModelName?.toString() ?? ''); + } + } + + if (ai_provider_name === AZURE_AI) { + if (model_config?.azureFoundryUrl) { + headers.set( + HEADER_KEYS.AZURE_FOUNDRY_URL, + model_config.azureFoundryUrl ?? '' + ); + } else { + const foundryURL = constructAzureFoundryURL({ + azureDeploymentName: model_config?.azureDeploymentName, + azureDeploymentType: model_config?.azureDeploymentType, + azureEndpointName: model_config?.azureEndpointName, + azureRegion: model_config?.azureRegion, + }); + headers.set(HEADER_KEYS.AZURE_FOUNDRY_URL, foundryURL ?? ''); + headers.set( + HEADER_KEYS.AZURE_DEPLOYMENT_NAME, + model_config?.azureDeploymentName ?? '' + ); + } + headers.set( + HEADER_KEYS.AZURE_API_VERSION, + model_config?.azureApiVersion ?? '' + ); + headers.set(HEADER_KEYS.AZURE_AUTH_MODE, model_config?.azureAuthMode ?? ''); + headers.set( + HEADER_KEYS.AZURE_MANAGED_CLIENT_ID, + model_config?.azureManagedClientId ?? '' + ); + headers.set( + HEADER_KEYS.AZURE_ENTRA_CLIENT_ID, + model_config?.azureEntraClientId ?? '' + ); + headers.set( + HEADER_KEYS.AZURE_ENTRA_CLIENT_SECRET, + model_config?.azureEntraClientSecret ?? '' + ); + headers.set( + HEADER_KEYS.AZURE_ENTRA_TENANT_ID, + model_config?.azureEntraTenantId ?? '' + ); + } + + if ([BEDROCK, SAGEMAKER].includes(ai_provider_name) && model_config) { + const { + awsAuthType, + awsAccessKeyId, + awsSecretAccessKey, + awsRegion, + awsRoleArn, + awsExternalId, + } = model_config; + + headers.set(HEADER_KEYS.AWS_AUTH_TYPE, awsAuthType?.toString() ?? ''); + headers.set(HEADER_KEYS.AWS_ROLE_ARN, awsRoleArn?.toString() ?? ''); + headers.set(HEADER_KEYS.AWS_EXTERNAL_ID, awsExternalId?.toString() ?? ''); + + headers.set( + HEADER_KEYS.BEDROCK_ACCESS_KEY_ID, + awsAccessKeyId?.toString() ?? '' + ); + headers.set( + HEADER_KEYS.BEDROCK_SECRET_ACCESS_KEY, + awsSecretAccessKey?.toString() ?? 
'' + ); + headers.set(HEADER_KEYS.BEDROCK_REGION, awsRegion?.toString() ?? ''); + headers.delete('authorization'); + } + + if (ai_provider_name === SAGEMAKER && model_config) { + const { + amznSagemakerCustomAttributes, + amznSagemakerTargetModel, + amznSagemakerTargetVariant, + amznSagemakerTargetContainerHostname, + amznSagemakerInferenceId, + amznSagemakerEnableExplanations, + amznSagemakerInferenceComponent, + amznSagemakerSessionId, + amznSagemakerModelName, + } = model_config; + headers.set( + HEADER_KEYS.SAGEMAKER_CUSTOM_ATTRIBUTES, + amznSagemakerCustomAttributes?.toString() ?? '' + ); + headers.set( + HEADER_KEYS.SAGEMAKER_ENABLE_EXPLANATIONS, + amznSagemakerEnableExplanations?.toString() ?? '' + ); + headers.set( + HEADER_KEYS.SAGEMAKER_INFERENCE_COMPONENT, + amznSagemakerInferenceComponent?.toString() ?? '' + ); + headers.set( + HEADER_KEYS.SAGEMAKER_INFERENCE_ID, + amznSagemakerInferenceId?.toString() ?? '' + ); + headers.set( + HEADER_KEYS.SAGEMAKER_SESSION_ID, + amznSagemakerSessionId?.toString() ?? '' + ); + headers.set( + HEADER_KEYS.SAGEMAKER_TARGET_CONTAINER_HOSTNAME, + amznSagemakerTargetContainerHostname?.toString() ?? '' + ); + headers.set( + HEADER_KEYS.SAGEMAKER_TARGET_MODEL, + amznSagemakerTargetModel?.toString() ?? '' + ); + headers.set( + HEADER_KEYS.SAGEMAKER_TARGET_VARIANT, + amznSagemakerTargetVariant?.toString() ?? '' + ); + headers.set( + HEADER_KEYS.SAGEMAKER_MODEL_NAME, + amznSagemakerModelName?.toString() ?? '' + ); + } + + if (ai_provider_name === VERTEX_AI && model_config) { + const { vertexProjectId, vertexRegion, vertexServiceAccountJson } = + model_config; + + headers.set( + HEADER_KEYS.VERTEX_AI_PROJECT_ID, + vertexProjectId?.toString() ?? '' + ); + + if (!headers.get(HEADER_KEYS.VERTEX_AI_REGION)) { + headers.set(HEADER_KEYS.VERTEX_AI_REGION, vertexRegion?.toString() ?? 
''); + } + + if (vertexServiceAccountJson) { + headers.set( + HEADER_KEYS.VERTEX_SERVICE_ACCOUNT_JSON, + JSON.stringify(vertexServiceAccountJson) + ); + } + } + + if (ai_provider_name === WORKERS_AI && model_config) { + const { workersAiAccountId } = model_config; + + headers.set( + HEADER_KEYS.WORKERS_AI_ACCOUNT_ID, + workersAiAccountId?.toString() ?? '' + ); + } + + if (ai_provider_name === OPEN_AI && model_config) { + const { openaiOrganization, openaiProject } = model_config; + + if (openaiOrganization) { + headers.set( + HEADER_KEYS.OPEN_AI_ORGANIZATION, + openaiOrganization?.toString() ?? '' + ); + } + + if (openaiProject) { + headers.set(HEADER_KEYS.OPEN_AI_PROJECT, openaiProject?.toString() ?? ''); + } + } + + headers.set(HEADER_KEYS.PROVIDER, ai_provider_name); + + return { + status: 'success', + }; +}; + +export const createRequestFromPromptData = ( + env: any, + promptData: Record, + promptPartialMap: Record, + requestBodyJSON: Record, + promptSlug: string +) => { + let promptString = promptData.string; + + const jsonCompatibleVariables: Record = {}; + Object.keys(requestBodyJSON.variables).forEach((key) => { + jsonCompatibleVariables[key] = requestBodyJSON.variables[key]; + }); + + const finalPromptPartials: Record = {}; + for (const key in promptPartialMap) { + const partial = promptPartialMap[key]; + finalPromptPartials[key] = JSON.stringify( + Mustache.render(partial.string, jsonCompatibleVariables, {}) + ).slice(1, -1); + } + + try { + promptString = Mustache.render( + promptString, + jsonCompatibleVariables, + finalPromptPartials + ); + } catch (e) { + return { + status: 'failure', + message: `Error in parsing prompt template: ${e}`, + }; + } + + const requestBody = { + ...promptData.parameters_object, + }; + + delete requestBody['stream']; + + Object.entries(requestBodyJSON).forEach(([key, value]) => { + requestBody[key] = value; + }); + delete requestBody['variables']; + + const requestHeader = { + [HEADER_KEYS.VIRTUAL_KEY]: 
promptData.provider_key_slug, + [HEADER_KEYS.PROMPT_ID]: promptData.id, + [HEADER_KEYS.PROMPT_VERSION_ID]: promptData.prompt_version_id, + [HEADER_KEYS.PROMPT_SLUG]: promptSlug, + }; + + if (promptData.template_metadata) { + const metadata = JSON.parse(promptData.template_metadata); + if (metadata.azure_alias) { + requestBody['model'] = metadata.azure_alias; + } + } + + let requestUrl = env.GATEWAY_BASEPATH; + + if (promptData.ai_model_type === 'chat') { + requestBody.messages = JSON.parse(promptString); + requestUrl += '/chat/completions'; + } else { + requestBody.prompt = promptString; + requestUrl += '/completions'; + } + return { requestBody, requestHeader, requestUrl }; +}; + +export const checkRateLimits = async ( + env: any, + organisationId: string, + rateLimitObject: RateLimit, + value: string, + type: string, + units: number +) => { + const apiRateLimiterStub = env.API_RATE_LIMITER.get( + env.API_RATE_LIMITER.idFromName( + `${organisationId}-${rateLimitObject.type}-${type}-${value}` + ) + ); + const apiRes = await apiRateLimiterStub.fetch( + 'https://api.portkey.ai/v1/health', + { + method: 'POST', + body: JSON.stringify({ + windowSize: RATE_LIMIT_UNIT_TO_WINDOW_MAPPING[rateLimitObject.unit], + capacity: rateLimitObject.value, + units, + }), + } + ); + + const apiMillisecondsToNextRequest = await apiRes.json(); + + if (apiMillisecondsToNextRequest > 0) { + return { + status: false, + waitTime: apiMillisecondsToNextRequest, + }; + } + + return { + allowed: true, + waitTime: apiMillisecondsToNextRequest, + }; +}; + +export const getRateLimit = async ( + env: any, + organisationId: string, + rateLimitObject: RateLimit, + value: string, + type: RateLimiterKeyTypes +) => { + const apiRateLimiterStub = env.API_RATE_LIMITER.get( + env.API_RATE_LIMITER.idFromName( + `${organisationId}-${rateLimitObject.type}-${type}-${value}` + ) + ); + const resp = await apiRateLimiterStub.fetch( + 'https://api.portkey.ai/v1/health', + { + method: 'POST', + body: 
JSON.stringify({ + windowSize: RATE_LIMIT_UNIT_TO_WINDOW_MAPPING[rateLimitObject.unit], + capacity: rateLimitObject.value, + getTokens: true, + }), + } + ); + const tokenResp = await resp.json(); + return tokenResp.tokens; +}; diff --git a/src/middlewares/portkey/handlers/hooks.ts b/src/middlewares/portkey/handlers/hooks.ts new file mode 100644 index 000000000..3d0e9a449 --- /dev/null +++ b/src/middlewares/portkey/handlers/hooks.ts @@ -0,0 +1,137 @@ +import { Context } from 'hono'; +import { + HookResult, + HookResultWithLogDetails, + LogObject, + WinkyLogObject, +} from '../types'; +import { addFeedback } from './feedback'; +import { env } from 'hono/adapter'; +import { forwardHookResultsToWinky, forwardLogsToWinky } from './logger'; + +async function processFeedback( + c: Context, + hookResults: any, + orgDetailsHeader: any, + traceId: string +): Promise { + const feedbackPromises = hookResults.flatMap((hookResult: HookResult) => { + if (hookResult.feedback && hookResult.feedback.value) { + return addFeedback( + env(c), + { + ...hookResult.feedback, + trace_id: traceId, + }, + orgDetailsHeader + ); + } + return []; + }); + + return Promise.all(feedbackPromises); +} + +export async function hookHandler( + c: Context, + hookSpanId: string, + orgDetailsHeader: any, + winkyLogObject: WinkyLogObject +): Promise { + try { + const hooksManager = c.get('hooksManager'); + await hooksManager.executeHooks( + hookSpanId, + ['asyncBeforeRequestHook', 'asyncAfterRequestHook'], + { + env: env(c), + getFromCacheByKey: c.get('getFromCacheByKey'), + putInCacheWithValue: c.get('putInCacheWithValue'), + } + ); + + const span = hooksManager.getSpan(hookSpanId); + const results = span.getHooksResult(); + + const guardrailVersionIdMap: Record = {}; + // Create a map of guardrail id and version id for logging. + // This is required because we cannot mention this in opensource code. 
+ span.getBeforeRequestHooks()?.forEach((brh: any) => { + guardrailVersionIdMap[brh.id] = brh.guardrailVersionId || ''; + }); + span.getAfterRequestHooks()?.forEach((arh: any) => { + guardrailVersionIdMap[arh.id] = arh.guardrailVersionId || ''; + }); + + await processFeedback( + c, + [...results.beforeRequestHooksResult, ...results.afterRequestHooksResult], + orgDetailsHeader, + winkyLogObject.traceId + ); + + // Add event type to each result + const beforeRequestHooksResultWithType: HookResultWithLogDetails[] = + results.beforeRequestHooksResult.map( + (result: HookResult): HookResultWithLogDetails => ({ + ...result, + event_type: 'beforeRequestHook', + guardrail_version_id: guardrailVersionIdMap[result.id] || '', + }) + ); + + const afterRequestHooksResultWithType: HookResultWithLogDetails[] = + results.afterRequestHooksResult.map( + (result: HookResult): HookResultWithLogDetails => ({ + ...result, + event_type: 'afterRequestHook', + guardrail_version_id: guardrailVersionIdMap[result.id] || '', + }) + ); + + const allResultsWithType = [ + ...beforeRequestHooksResultWithType, + ...afterRequestHooksResultWithType, + ]; + + const hookLogs: LogObject[] = []; + + // Find all checks that have a log object in it + // and push it to the hookLogs array + for (const { checks } of allResultsWithType) { + for (const check of checks) { + if (check.log) { + // add traceId to the log object + check.log.metadata.traceId = winkyLogObject.traceId; + check.log.organisationDetails = + winkyLogObject.config.organisationDetails; + hookLogs.push(check.log); + delete check.log; + } + } + } + + // Send the hook logs to the `/v1/logs` endpoint + if (hookLogs.length > 0) { + await forwardLogsToWinky(env(c), hookLogs); + } + + if (allResultsWithType.length > 0) { + // Log the hook results + await forwardHookResultsToWinky(env(c), { + generation_id: winkyLogObject.id, + trace_id: winkyLogObject.traceId, + organisation_id: winkyLogObject.config.organisationDetails.id, + workspace_slug: + 
winkyLogObject.config.organisationDetails.workspaceDetails.slug, + internal_trace_id: winkyLogObject.internalTraceId, + results: allResultsWithType, + organisation_details: winkyLogObject.config.organisationDetails, + }); + } + return true; + } catch (err) { + console.log('hooks err:', err); + return false; + } +} diff --git a/src/middlewares/portkey/handlers/kv.ts b/src/middlewares/portkey/handlers/kv.ts new file mode 100644 index 000000000..653440ed6 --- /dev/null +++ b/src/middlewares/portkey/handlers/kv.ts @@ -0,0 +1,80 @@ +import { INTERNAL_HEADER_KEYS } from '../globals'; + +/** + * Asynchronously fetch data from the KV store. + * + * @param {any} env - Hono environment object. + * @param {string} key - The key that needs to be retrieved from the KV store. + * @returns {Promise} - A Promise that resolves to the fetched data or null if an error occurs. + */ +export const fetchFromKVStore = async ( + env: any, + key: string +): Promise => { + if (!env.KV_STORE_WORKER_BASEPATH) { + return null; + } + const requestURL = `${env.KV_STORE_WORKER_BASEPATH}/${key}`; + const fetchOptions = { + headers: { + [INTERNAL_HEADER_KEYS.CLIENT_AUTH_SECRET]: env.CLIENT_ID, + }, + }; + try { + const response = await env.kvStoreWorker.fetch(requestURL, fetchOptions); + if (response.ok) { + return response.json(); + } else if (response.status !== 404) { + console.log( + 'invalid response from kv-store', + response.status, + await response.clone().text() + ); + } + } catch (error) { + console.log('kv fetch error', error); + } + + return null; +}; + +/** + * Asynchronously puts data into the KV store. + * + * @param {any} env - Hono environment object. + * @param {string} key - The key that needs to be stored with the value in the KV store. + * @param {string} value - The data to be stored in the KV store. + * @param {number} [expiry] - Optional expiration time for the stored data (in seconds). 
+ * @returns {Promise} - A Promise that resolves when the data is successfully stored or logs an error if it occurs. + */ +export const putInKVStore = async ( + env: any, + key: string, + value: string, + expiry?: number +): Promise => { + if (!env.KV_STORE_WORKER_BASEPATH) { + return; + } + const requestURL = `${env.KV_STORE_WORKER_BASEPATH}/put`; + const fetchOptions = { + method: 'PUT', + body: JSON.stringify({ key, value, expiry }), + headers: { + [INTERNAL_HEADER_KEYS.CLIENT_AUTH_SECRET]: env.CLIENT_ID, + }, + }; + + try { + const response = await env.kvStoreWorker.fetch(requestURL, fetchOptions); + if (!response.ok) { + console.log( + 'failed status code from kv-store', + response.status, + await response.clone().text() + ); + } + } catch (err) { + console.log('kv put error', err); + } +}; diff --git a/src/middlewares/portkey/handlers/logger.ts b/src/middlewares/portkey/handlers/logger.ts new file mode 100644 index 000000000..38d8927f0 --- /dev/null +++ b/src/middlewares/portkey/handlers/logger.ts @@ -0,0 +1,49 @@ +import { HookResultLogObject, LogObject, WinkyLogObject } from '../types'; + +export async function forwardToWinky(env: any, winkyLogObject: WinkyLogObject) { + try { + await env.winky.fetch(env.WINKY_WORKER_BASEPATH, { + method: 'POST', + body: JSON.stringify(winkyLogObject), + headers: { + 'Content-Type': 'application/json', + 'x-client-id-gateway': env.CLIENT_ID, + }, + }); + } catch (error) { + console.error(error); + } +} + +export async function forwardLogsToWinky(env: any, logObjects: LogObject[]) { + try { + await env.winky.fetch(`${env.GATEWAY_BASEPATH}/logs`, { + method: 'POST', + body: JSON.stringify(logObjects), + headers: { + 'Content-Type': 'application/json', + 'x-client-id-gateway': env.CLIENT_ID, + }, + }); + } catch (error) { + console.error(error); + } +} + +export async function forwardHookResultsToWinky( + env: any, + logObject: HookResultLogObject +) { + try { + await 
env.winky.fetch(`${env.WINKY_WORKER_BASEPATH}/logs/hook-results`, { + method: 'POST', + body: JSON.stringify(logObject), + headers: { + 'Content-Type': 'application/json', + 'x-client-id-gateway': env.CLIENT_ID, + }, + }); + } catch (error) { + console.error(error); + } +} diff --git a/src/middlewares/portkey/handlers/rateLimits.ts b/src/middlewares/portkey/handlers/rateLimits.ts new file mode 100644 index 000000000..eba54ee83 --- /dev/null +++ b/src/middlewares/portkey/handlers/rateLimits.ts @@ -0,0 +1,247 @@ +import { + CACHE_STATUS, + HEADER_KEYS, + RATE_LIMIT_UNIT_TO_WINDOW_MAPPING, + RateLimiterKeyTypes, + RateLimiterTypes, +} from '../globals'; +import { RateLimit, WinkyLogObject } from '../types'; +import RedisRateLimiter from '../../../shared/services/cache/utils/rateLimiter'; +import { getRuntimeKey } from 'hono/adapter'; +import { getDefaultCache } from '../../../shared/services/cache'; +import { RedisCacheBackend } from '../../../shared/services/cache/backends/redis'; + +export function generateRateLimitKey( + organisationId: string, + rateLimitType: RateLimiterTypes, + keyType: RateLimiterKeyTypes, + key: string, + rateLimitUnit: string +) { + return `${organisationId}-${ + rateLimitType || RateLimiterTypes.REQUESTS + }-${keyType}-${key}-${rateLimitUnit}`; +} + +export function preRequestRateLimitValidator({ + env, + rateLimits, + key, + keyType, + maxTokens, + organisationId, +}: { + env: Record; + rateLimits: RateLimit[]; + key: string; + keyType: RateLimiterKeyTypes; + maxTokens: number; + organisationId: string; +}) { + const promises: Promise[] = []; + for (const rateLimit of rateLimits) { + if (rateLimit.unit && rateLimit.value) { + const rateLimitKey = generateRateLimitKey( + organisationId, + rateLimit.type, + keyType, + key, + rateLimit.unit + ); + if (getRuntimeKey() === 'node') { + const redisClient = getDefaultCache().getClient() as RedisCacheBackend; + if (!redisClient) { + console.warn( + 'you need to set the REDIS_CONNECTION_STRING 
environment variable for rate limits to work'
+          );
+          const promise = new Promise((resolve) => {
+            resolve(
+              new Response(
+                JSON.stringify({
+                  allowed: true,
+                  waitTime: 0,
+                })
+              )
+            );
+          });
+          promises.push(promise);
+        }
+        const rateLimiter = new RedisRateLimiter(
+          redisClient,
+          rateLimit.value,
+          RATE_LIMIT_UNIT_TO_WINDOW_MAPPING[rateLimit.unit],
+          rateLimitKey,
+          keyType
+        );
+        if (rateLimit.type === RateLimiterTypes.TOKENS) {
+          promises.push(
+            new Promise(async (resolve) => {
+              const result = await rateLimiter.checkRateLimit(maxTokens, false);
+              resolve(new Response(JSON.stringify(result)));
+            })
+          );
+        } else {
+          promises.push(
+            new Promise(async (resolve) => {
+              const result = await rateLimiter.checkRateLimit(1, true);
+              resolve(new Response(JSON.stringify(result)));
+            })
+          );
+        }
+      } else {
+        const apiRateLimiterStub = env.RATE_LIMITER.get(
+          env.RATE_LIMITER.idFromName(rateLimitKey)
+        );
+        if (rateLimit.type === RateLimiterTypes.TOKENS) {
+          promises.push(
+            apiRateLimiterStub.fetch('https://example.com', {
+              method: 'POST',
+              body: JSON.stringify({
+                windowSize: RATE_LIMIT_UNIT_TO_WINDOW_MAPPING[rateLimit.unit],
+                capacity: rateLimit.value,
+                decrementTokens: false,
+                units: maxTokens,
+                keyType: keyType,
+                key: key,
+                rateLimitType: rateLimit.type,
+              }),
+            })
+          );
+        } else {
+          promises.push(
+            apiRateLimiterStub.fetch('https://example.com', {
+              method: 'POST',
+              body: JSON.stringify({
+                windowSize: RATE_LIMIT_UNIT_TO_WINDOW_MAPPING[rateLimit.unit],
+                capacity: rateLimit.value,
+                decrementTokens: true,
+                units: 1,
+                keyType: keyType,
+                key: key,
+                rateLimitType: rateLimit.type,
+              }),
+            })
+          );
+        }
+      }
+    }
+  }
+  return promises;
+}
+
+export const decrementRateLimits = async (
+  env: any,
+  organisationId: string,
+  rateLimitObject: RateLimit,
+  cacheKey: string,
+  type: RateLimiterKeyTypes,
+  units: number
+) => {
+  if (getRuntimeKey() === 'node') {
+    const redisClient = getDefaultCache().getClient() as RedisCacheBackend;
+    if (!redisClient) {
+      
console.warn(
+        'you need to set the REDIS_CONNECTION_STRING environment variable for rate limits to work'
+      );
+      return {
+        allowed: true,
+        waitTime: 0,
+      };
+    }
+    const rateLimiter = new RedisRateLimiter(
+      redisClient,
+      rateLimitObject.value,
+      RATE_LIMIT_UNIT_TO_WINDOW_MAPPING[rateLimitObject.unit],
+      cacheKey,
+      type
+    );
+    const resp = await rateLimiter.decrementToken(units);
+    return {
+      allowed: resp.allowed,
+      waitTime: resp.waitTime,
+    };
+  }
+  const apiRateLimiterStub = env.API_RATE_LIMITER.get(
+    env.API_RATE_LIMITER.idFromName(cacheKey)
+  );
+  const apiRes = await apiRateLimiterStub.fetch(
+    'https://api.portkey.ai/v1/health',
+    {
+      method: 'POST',
+      body: JSON.stringify({
+        windowSize: RATE_LIMIT_UNIT_TO_WINDOW_MAPPING[rateLimitObject.unit],
+        capacity: rateLimitObject.value,
+        units,
+      }),
+    }
+  );
+
+  const apiMillisecondsToNextRequest = await apiRes.json();
+
+  if (apiMillisecondsToNextRequest > 0) {
+    return {
+      status: false,
+      waitTime: apiMillisecondsToNextRequest,
+    };
+  }
+
+  return {
+    allowed: true,
+    waitTime: apiMillisecondsToNextRequest,
+  };
+};
+
+export const handleIntegrationRequestRateLimits = async (
+  env: any,
+  chLogObject: WinkyLogObject,
+  units: number
+) => {
+  const organisationDetails = chLogObject.config.organisationDetails;
+  const integrationDetails =
+    chLogObject.config?.portkeyHeaders?.[HEADER_KEYS.INTEGRATION_DETAILS];
+  if (!integrationDetails) {
+    return;
+  }
+  const integrationDetailsObj =
+    typeof integrationDetails === 'string'
+      ? JSON.parse(integrationDetails)
+      : integrationDetails;
+  const rateLimits = integrationDetailsObj.rate_limits ?? 
[]; + const tokenRateLimit = rateLimits?.filter( + (rl: any) => rl.type === RateLimiterTypes.TOKENS + )?.[0]; + const key = `${integrationDetailsObj.id}-${organisationDetails.workspaceDetails?.id}`; + const rateLimitKey = generateRateLimitKey( + organisationDetails.id, + tokenRateLimit.type, + RateLimiterKeyTypes.INTEGRATION_WORKSPACE, + key, + tokenRateLimit.unit + ); + if (tokenRateLimit) { + const vkRateLimits = + typeof rateLimits === 'string' ? JSON.parse(rateLimits) : rateLimits; + const requestsRateLimit = vkRateLimits?.filter( + (rl: any) => rl.type === RateLimiterTypes.TOKENS + )?.[0]; + const isCacheHit = [CACHE_STATUS.HIT, CACHE_STATUS.SEMANTIC_HIT].includes( + chLogObject.config.cacheStatus + ); + if (!isCacheHit && requestsRateLimit) { + const virtualKey = + chLogObject.config.portkeyHeaders?.[HEADER_KEYS.VIRTUAL_KEY]; + const requestRateLimitCheckObject = { + value: virtualKey, + rateLimits: requestsRateLimit, + }; + await decrementRateLimits( + env, + chLogObject.config.organisationDetails.id, + requestRateLimitCheckObject.rateLimits, + rateLimitKey, + RateLimiterKeyTypes.INTEGRATION_WORKSPACE, + units + ); + } + } +}; diff --git a/src/middlewares/portkey/handlers/realtime.ts b/src/middlewares/portkey/handlers/realtime.ts new file mode 100644 index 000000000..218a095f6 --- /dev/null +++ b/src/middlewares/portkey/handlers/realtime.ts @@ -0,0 +1,78 @@ +import { Context } from 'hono'; +import { env } from 'hono/adapter'; +import { forwardToWinky } from './logger'; +import { getDebugLogSetting, getPortkeyHeaders } from '../utils'; +import { HEADER_KEYS, MODES } from '../globals'; +import { OrganisationDetails, WinkyLogObject } from '../types'; + +export async function realtimeEventLogHandler( + c: Context, + sessionOptions: Record, + req: Record, + res: Record, + eventType: string +): Promise { + try { + const logId = crypto.randomUUID(); + let metadata: Record = {}; + try { + metadata = JSON.parse( + sessionOptions.requestHeaders['x-portkey-metadata'] + 
); + } catch (err) { + metadata = {}; + } + + metadata._realtime_event_type = eventType; + metadata._category = MODES.REALTIME; + const headersObj = { + ...sessionOptions.requestHeaders, + 'x-portkey-span-name': eventType, + 'x-portkey-metadata': JSON.stringify(metadata), + }; + const portkeyHeaders = getPortkeyHeaders(headersObj); + const headersWithoutPortkeyHeaders = Object.assign({}, headersObj); //deep copy + Object.values(HEADER_KEYS).forEach((eachPortkeyHeader) => { + delete headersWithoutPortkeyHeaders[eachPortkeyHeader]; + }); + const orgDetails = JSON.parse(headersObj[HEADER_KEYS.ORGANISATION_DETAILS]); + const winkyLogObject: WinkyLogObject = { + id: logId, + traceId: portkeyHeaders['x-portkey-trace-id'], + createdAt: new Date(), + internalTraceId: sessionOptions.id, + requestMethod: 'POST', + requestURL: sessionOptions.providerOptions.requestURL, + rubeusURL: sessionOptions.providerOptions.rubeusURL, + requestHeaders: headersWithoutPortkeyHeaders, + requestBody: JSON.stringify(req), + requestBodyParams: req, + responseStatus: eventType === 'error' ? 
500 : 200, + responseTime: 0, + responseBody: JSON.stringify(res), + responseHeaders: {}, + providerOptions: sessionOptions.providerOptions, + debugLogSetting: getDebugLogSetting(portkeyHeaders, orgDetails), + cacheKey: '', + config: { + organisationConfig: {}, + organisationDetails: orgDetails as OrganisationDetails, + cacheType: 'DISABLED', + retryCount: 0, + portkeyHeaders: portkeyHeaders, + proxyMode: MODES.REALTIME, + streamingMode: true, + cacheStatus: 'DISABLED', + provider: sessionOptions.providerOptions.provider, + internalTraceId: sessionOptions.id, + cacheMaxAge: null, + requestParams: req, + lastUsedOptionIndex: 0, + }, + }; + await forwardToWinky(env(c), winkyLogObject); + return true; + } catch (err: any) { + return false; + } +} diff --git a/src/middlewares/portkey/handlers/stream.ts b/src/middlewares/portkey/handlers/stream.ts new file mode 100644 index 000000000..a6c238452 --- /dev/null +++ b/src/middlewares/portkey/handlers/stream.ts @@ -0,0 +1,823 @@ +import { + ANTHROPIC, + ANYSCALE, + AZURE_OPEN_AI, + COHERE, + GOOGLE, + OPEN_AI, + PERPLEXITY_AI, + MISTRAL_AI, + TOGETHER_AI, + MODES, + OLLAMA, + NOVITA_AI, +} from '../globals'; +import { + AnthropicCompleteStreamResponse, + AnthropicMessagesStreamResponse, + CohereStreamResponse, + GoogleGenerateContentResponse, + OpenAIStreamResponse, + ParsedChunk, + TogetherAIResponse, + TogetherInferenceResponse, + OllamaCompleteStreamReponse, + OllamaChatCompleteStreamResponse, +} from '../types'; +import { parseAnthropicMessageStreamResponse } from '../utils/anthropicMessagesStreamParser'; + +export const getStreamModeSplitPattern = ( + proxyProvider: string, + requestURL: string +) => { + let splitPattern = '\n\n'; + if (proxyProvider === ANTHROPIC && requestURL.endsWith('complete')) { + splitPattern = '\r\n\r\n'; + } + if (proxyProvider === COHERE) { + splitPattern = '\n'; + } + if (proxyProvider === GOOGLE) { + splitPattern = '\r\n'; + } + if (proxyProvider === PERPLEXITY_AI) { + splitPattern = 
'\r\n\r\n'; + } + if (proxyProvider === OLLAMA) { + splitPattern = '\n'; + } + return splitPattern; +}; + +function parseOpenAIStreamResponse( + res: string, + splitPattern: string, + isStreamCompletionTokensAllowed: boolean +): OpenAIStreamResponse | undefined { + const arr = res.split(splitPattern); + const responseObj: OpenAIStreamResponse = { + id: '', + object: '', + created: '', + choices: [], + model: '', + usage: { + completion_tokens: 0, + }, + }; + let isConcatenationError = false; + arr.forEach((eachFullChunk) => { + if (isConcatenationError) { + return; + } + eachFullChunk = eachFullChunk + .trim() + .replace(/^data: /, '') + .replace(/^: ?(.*)$/gm, '') + .trim(); + + if (!eachFullChunk || eachFullChunk === '[DONE]') { + return responseObj; + } + + let currentIndex: number = 0; + try { + const parsedChunk: Record = JSON.parse( + eachFullChunk || '{}' + ); + if (parsedChunk.choices && parsedChunk.choices[0]?.index >= 0) { + currentIndex = parsedChunk.choices[0].index; + } + + if (parsedChunk.choices?.[0]?.delta) { + const isEmptyChunk = !parsedChunk.choices[0].delta; + responseObj.id = parsedChunk.id; + responseObj.object = parsedChunk.object; + responseObj.created = parsedChunk.created; + responseObj.model = parsedChunk.model; + + if (!responseObj.choices[currentIndex]) { + responseObj.choices[currentIndex] = { + index: '', + finish_reason: '', + message: { + role: 'assistant', + content: '', + }, + }; + } + + const currentChoice = responseObj.choices[currentIndex]; + + currentChoice.index = parsedChunk.choices[0].index; + currentChoice.finish_reason = parsedChunk.choices[0].finish_reason; + + if (!isEmptyChunk) { + const toolCall = parsedChunk.choices[0].delta?.tool_calls?.[0]; + const toolCallIndex = toolCall?.index || 0; + + if ( + currentChoice.message && + toolCall && + !currentChoice.message.tool_calls + ) { + currentChoice.message.tool_calls = []; + } + + if (currentChoice.message && toolCall) { + const currentToolCall = + 
currentChoice.message.tool_calls[toolCallIndex] || {}; + + if (toolCall.id) { + currentToolCall.id = toolCall.id; + } + if (toolCall.type) { + currentToolCall.type = toolCall.type; + } + if (toolCall.function) { + if (!currentToolCall.function) { + currentToolCall.function = {}; + } + if (toolCall.function.name) { + currentToolCall.function.name = toolCall.function.name; + } + if (toolCall.function.arguments) { + currentToolCall.function.arguments = + (currentToolCall.function.arguments || '') + + toolCall.function.arguments; + } + } + + currentChoice.message.tool_calls[toolCallIndex] = currentToolCall; + } + + const contentBlock = + parsedChunk.choices[0].delta?.content_blocks?.[0]; + const contentBlockIndex = contentBlock?.index || 0; + + if ( + currentChoice.message && + contentBlock && + !currentChoice.message.content_blocks + ) { + currentChoice.message.content_blocks = []; + } + + if (currentChoice.message?.content_blocks && contentBlock) { + const currentContentBlock = + currentChoice.message.content_blocks[contentBlockIndex] || {}; + + if (contentBlock.delta.thinking) { + if (!currentContentBlock.thinking) { + currentContentBlock.thinking = ''; + currentContentBlock.type = 'thinking'; + } + currentContentBlock.thinking += contentBlock.delta.thinking; + } + if (contentBlock.delta.signature) { + if (!currentContentBlock.signature) { + currentContentBlock.signature = ''; + currentContentBlock.type = 'thinking'; + } + currentContentBlock.signature += contentBlock.delta.signature; + } + if (contentBlock.delta.data) { + if (!currentContentBlock.data) { + currentContentBlock.data = ''; + currentContentBlock.type = 'redacted_thinking'; + } + currentContentBlock.data += contentBlock.delta.data; + } + if (contentBlock.delta.text) { + if (!currentContentBlock.text) { + currentContentBlock.text = ''; + currentContentBlock.type = 'text'; + } + currentContentBlock.text += contentBlock.delta.text; + } + + currentChoice.message.content_blocks[contentBlockIndex] = + 
currentContentBlock; + } + + if (currentChoice.message && parsedChunk.choices[0].delta.content) { + currentChoice.message.content += + parsedChunk.choices[0].delta.content; + responseObj.usage.completion_tokens++; + } + if (parsedChunk.choices[0].groundingMetadata) { + currentChoice.groundingMetadata = + parsedChunk.choices[0].groundingMetadata; + } + } + } else if ( + eachFullChunk !== '[DONE]' && + parsedChunk.choices?.[0]?.text != null + ) { + responseObj.id = parsedChunk.id; + responseObj.object = parsedChunk.object; + responseObj.created = parsedChunk.created; + responseObj.model = parsedChunk.model; + + if (!responseObj.choices[currentIndex]) { + responseObj.choices[currentIndex] = { + text: '', + index: '', + logprobs: '', + finish_reason: '', + }; + } + const currentChoice = responseObj.choices[currentIndex]; + + currentChoice.text += parsedChunk.choices[0].text; + responseObj.usage.completion_tokens++; + + currentChoice.index = parsedChunk.choices[0].index; + currentChoice.logprobs = parsedChunk.choices[0].logprobs; + currentChoice.finish_reason = parsedChunk.choices[0].finish_reason; + } + + // Portkey cache hits adds usage object with completion tokens to each stream chunk. + // This is done to avoid calculating tokens again for cache hits. + // If its not present, then increment completion_tokens on each chunk. + if (parsedChunk.usage && parsedChunk.usage.completion_tokens) { + responseObj.usage.completion_tokens = + parsedChunk.usage.completion_tokens; + } + + // Cache tokens are sent for anthropic as part of prompt-caching feature. 
+ if ( + parsedChunk.usage && + (parsedChunk.usage.cache_read_input_tokens || + parsedChunk.usage.cache_creation_input_tokens) + ) { + responseObj.usage.cache_read_input_tokens = + parsedChunk.usage.cache_read_input_tokens; + responseObj.usage.cache_creation_input_tokens = + parsedChunk.usage.cache_creation_input_tokens; + } + + // Anthropic sends prompt and completion tokens in separate chunks + // So adding 2 different conditions are required for prompt and completion + if (parsedChunk.usage && parsedChunk.usage.prompt_tokens) { + responseObj.usage.prompt_tokens = parsedChunk.usage.prompt_tokens; + } + + if ( + parsedChunk.usage && + parsedChunk.usage.prompt_tokens && + parsedChunk.usage.completion_tokens + ) { + responseObj.usage.prompt_tokens = parsedChunk.usage.prompt_tokens; + responseObj.usage.completion_tokens = + parsedChunk.usage.completion_tokens; + responseObj.usage.total_tokens = + parsedChunk.usage.total_tokens ?? + parsedChunk.usage.prompt_tokens + parsedChunk.usage.completion_tokens; + responseObj.usage.num_search_queries = + parsedChunk.usage.num_search_queries; + } + + if (parsedChunk.usage?.completion_tokens_details) { + responseObj.usage.completion_tokens_details = + parsedChunk.usage.completion_tokens_details; + } + + if (parsedChunk.usage?.prompt_tokens_details) { + responseObj.usage.prompt_tokens_details = + parsedChunk.usage.prompt_tokens_details; + } + + if (parsedChunk.citations) { + responseObj.citations = parsedChunk.citations; + } + } catch (error) { + console.error('parseOpenAIStreamResponse error', error, eachFullChunk); + isConcatenationError = true; + } + }); + + if (isConcatenationError) { + return; + } + + if (!isStreamCompletionTokensAllowed) { + responseObj.usage.completion_tokens = 0; + } + return responseObj; +} + +function parseCohereStreamResponse( + res: string, + splitPattern: string +): CohereStreamResponse { + const arr = res.split(splitPattern); + let responseObj: CohereStreamResponse = { + id: '', + generations: [ + { 
+ id: '', + text: '', + finish_reason: '', + }, + ], + prompt: '', + }; + let lastChunk: CohereStreamResponse | undefined; + + arr.forEach((eachFullChunk) => { + eachFullChunk = eachFullChunk.trim(); + try { + const parsedChunk: ParsedChunk = JSON.parse(eachFullChunk || '{}'); + + if (parsedChunk.is_finished && parsedChunk.response) { + lastChunk = parsedChunk.response; + } + } catch (error) { + console.error(error); + } + }); + + return lastChunk || responseObj; +} + +function parseGoogleStreamResponse( + res: string, + splitPattern: string +): GoogleGenerateContentResponse { + let response: GoogleGenerateContentResponse = { + candidates: [ + { + content: { + parts: [ + { + text: '', + }, + ], + role: 'model', + }, + finishReason: '', + index: 0, + safetyRatings: [], + }, + ], + promptFeedback: { + safetyRatings: [], + }, + }; + try { + const parsedResponse = JSON.parse(res); + parsedResponse.forEach((eachChunk: GoogleGenerateContentResponse) => { + const candidates = eachChunk.candidates; + if (eachChunk.promptFeedback) { + response.promptFeedback = eachChunk.promptFeedback; + } + candidates.forEach((candidate) => { + const index = candidate.index; + if (!response.candidates[index]) { + response.candidates[index] = { + content: { + parts: [ + { + text: '', + }, + ], + role: 'model', + }, + finishReason: '', + index: 0, + safetyRatings: [], + }; + } + + response.candidates[index].content.parts[0].text += + candidate.content.parts[0]?.text ?? 
''; + response.candidates[index].finishReason = candidate.finishReason; + response.candidates[index].safetyRatings = candidate.safetyRatings; + response.candidates[index].index = candidate.index; + }); + }); + } catch (error) { + console.log('google stream error', error); + } + return response; +} + +function parseTogetherAIInferenceStreamResponse( + res: string, + splitPattern: string +): TogetherInferenceResponse { + const arr = res.split(splitPattern); + let responseObj: TogetherInferenceResponse = { + status: 'finished', + output: { + choices: [], + request_id: '', + }, + }; + try { + arr.forEach((eachFullChunk) => { + eachFullChunk.trim(); + eachFullChunk = eachFullChunk.replace(/^data: /, ''); + eachFullChunk = eachFullChunk.trim(); + let currentIndex: number = 0; + if (!eachFullChunk || eachFullChunk === '[DONE]') { + return responseObj; + } + const parsedChunk: Record = JSON.parse( + eachFullChunk || '{}' + ); + if (parsedChunk.choices && parsedChunk.choices[0]?.index >= 0) { + currentIndex = parsedChunk.choices[0].index; + } + + if (eachFullChunk !== '[DONE]' && parsedChunk.choices?.[0]?.text) { + if (!responseObj.output.choices[currentIndex]) { + responseObj.output.choices[currentIndex || 0] = { + text: '', + }; + } + responseObj.output.choices[currentIndex].text += + parsedChunk.choices[0].text; + } + }); + } catch (error) { + console.error('together-ai inference stream error', error); + } + return responseObj; +} + +function parseTogetherAICompletionsStreamResponse( + res: string, + splitPattern: string +): TogetherAIResponse { + const arr = res.split(splitPattern); + let responseObj: TogetherAIResponse = { + id: '', + object: '', + created: '', + choices: [], + model: '', + }; + try { + arr.forEach((eachFullChunk) => { + eachFullChunk.trim(); + eachFullChunk = eachFullChunk.replace(/^data: /, ''); + eachFullChunk = eachFullChunk.trim(); + let currentIndex: number = 0; + if (!eachFullChunk || eachFullChunk === '[DONE]') { + return responseObj; + } + 
const parsedChunk: Record = JSON.parse( + eachFullChunk || '{}' + ); + if (parsedChunk.choices && parsedChunk.choices[0]?.index >= 0) { + currentIndex = parsedChunk.choices[0].index; + } + + if (parsedChunk.choices?.[0]?.delta) { + const isEmptyChunk = parsedChunk.choices[0].delta.content + ? false + : true; + responseObj.id = parsedChunk.id; + responseObj.object = parsedChunk.object; + responseObj.created = parsedChunk.created; + responseObj.model = parsedChunk.model; + + if (!responseObj.choices[currentIndex]) { + responseObj.choices[currentIndex] = { + message: { + role: 'assistant', + content: '', + }, + }; + } + + if (!isEmptyChunk) { + responseObj.choices[currentIndex].message = { + role: 'assistant', + content: + responseObj?.choices?.[currentIndex]?.message?.content + + parsedChunk.choices[0].delta.content, + }; + } + } else if (eachFullChunk !== '[DONE]' && parsedChunk.choices?.[0]?.text) { + responseObj.id = parsedChunk.id; + responseObj.object = parsedChunk.object; + responseObj.created = parsedChunk.created; + responseObj.model = parsedChunk.model; + + if (!responseObj.choices[currentIndex]) { + responseObj.choices[currentIndex || 0] = { + text: '', + }; + } + responseObj.choices[currentIndex].text += parsedChunk.choices[0].text; + } + }); + } catch (error) { + console.error('together-ai completions stream error', error); + } + return responseObj; +} + +function parseTogetherAIStreamResponse( + res: string, + splitPattern: string, + requestURL: string +): TogetherAIResponse | TogetherInferenceResponse { + let responseType = 'completions'; + if (requestURL.endsWith('/inference')) { + responseType = 'inference'; + } + + switch (responseType) { + case 'inference': + return parseTogetherAIInferenceStreamResponse(res, splitPattern); + default: + return parseTogetherAICompletionsStreamResponse(res, splitPattern); + } +} + +function parseOllamaStreamResponse( + res: string, + splitPattern: string, + requestURL: string +): OllamaCompleteStreamReponse | 
OllamaChatCompleteStreamResponse { + let responseType = 'generate'; + if (requestURL.endsWith('/chat')) { + responseType = 'chat'; + } + switch (responseType) { + case 'chat': + return parseOllamaChatCompleteStreamResponse(res, splitPattern); + default: + return parseOllamaCompleteStreamResponse(res, splitPattern); + } +} + +function parseOllamaCompleteStreamResponse( + res: string, + splitPattern: string +): OllamaCompleteStreamReponse { + const arr = res.split(splitPattern).slice(0, -1); + + let responseObj: OllamaCompleteStreamReponse = { + model: '', + created_at: 0, + response: '', + done: false, + context: [], + }; + + for (let eachChunk of arr) { + eachChunk.trim(); + try { + const parsedChunk = JSON.parse(eachChunk); + if (parsedChunk.context) { + responseObj.model = parsedChunk.model; + responseObj.created_at = parsedChunk.created_at; + responseObj.done = parsedChunk.done; + responseObj.context = parsedChunk.context; + } else { + responseObj.response += parsedChunk.response; + } + } catch (error) { + console.error('ollama complete stream error', error); + } + } + return responseObj; +} + +function parseOllamaChatCompleteStreamResponse( + res: string, + splitPattern: string +): OllamaChatCompleteStreamResponse { + const arr = res.split(splitPattern).slice(0, -1); + let responseObj: OllamaChatCompleteStreamResponse = { + model: '', + created_at: '', + message: { + role: '', + content: '', + }, + done: false, + total_duration: 0, + load_duration: 0, + prompt_eval_count: 0, + prompt_eval_duration: 0, + eval_count: 0, + eval_duration: 0, + }; + for (let eachChunk of arr) { + eachChunk.trim(); + try { + const parsedChunk = JSON.parse(eachChunk); + if (parsedChunk.done) { + responseObj.model = parsedChunk.model; + responseObj.created_at = parsedChunk.created_at; + responseObj.message.role = parsedChunk.message.role; + responseObj.done = parsedChunk.done; + responseObj.total_duration = parsedChunk.total_duration; + responseObj.load_duration = 
parsedChunk.load_duration; + responseObj.prompt_eval_count = parsedChunk?.prompt_eval_count; + responseObj.prompt_eval_duration = parsedChunk.prompt_eval_duration; + responseObj.eval_count = parsedChunk.eval_count; + responseObj.eval_duration = parsedChunk.eval_duration; + } else { + responseObj.message.content += parsedChunk.message.content; + } + } catch (error) { + console.error('ollama chat complete stream error', error); + } + } + + return responseObj; +} + +export function parseResponse( + res: string, + aiProvider: string, + proxyMode: string, + requestURL: string, + fn: string +) { + const splitPattern = getStreamModeSplitPattern(aiProvider, requestURL); + let isStreamCompletionTokensAllowed = true; + if ([NOVITA_AI, MISTRAL_AI].includes(aiProvider)) { + isStreamCompletionTokensAllowed = false; + } + if (![MODES.PROXY_V2, MODES.PROXY].includes(proxyMode)) { + if (requestURL.includes('/v1/responses')) { + return parseOpenAIResponsesStreamResponse(res, '\n\n'); + } + if (fn === 'messages') { + return parseAnthropicMessageStreamResponse(res, splitPattern); + } + if (fn === 'imageEdit') { + return parseOpenAIImageEditStreamResponse(res, splitPattern); + } + return parseOpenAIStreamResponse( + res, + '\n\n', + isStreamCompletionTokensAllowed + ); + } + switch (aiProvider) { + case AZURE_OPEN_AI: + case OPEN_AI: + case ANYSCALE: + case PERPLEXITY_AI: + return parseOpenAIStreamResponse(res, splitPattern, true); + case MISTRAL_AI: + return parseOpenAIStreamResponse(res, splitPattern, false); + case ANTHROPIC: + return parseAnthropicMessageStreamResponse(res, splitPattern); + case COHERE: + return parseCohereStreamResponse(res, splitPattern); + case GOOGLE: + return parseGoogleStreamResponse(res, splitPattern); + case TOGETHER_AI: + return parseTogetherAIStreamResponse(res, splitPattern, requestURL); + case OLLAMA: + return parseOllamaStreamResponse(res, splitPattern, requestURL); + default: + console.error('Code should not reach here'); + throw Error('Provider not 
supported in streaming'); + } +} + +export async function* readStream( + reader: ReadableStreamDefaultReader, + splitPattern: string, + transformFunction: Function | undefined +) { + let buffer = ''; + let decoder = new TextDecoder(); + const state = { + lastIndex: 0, + }; + while (true) { + const { done, value } = await reader.read(); + if (done) { + if (buffer.length > 0) { + if (transformFunction) { + yield transformFunction(buffer, state); + } else { + yield buffer; + } + } + break; + } + + buffer += decoder.decode(value, { stream: true }); + // keep buffering until we have a complete chunk + + while (buffer.split(splitPattern).length > 1) { + let parts = buffer.split(splitPattern); + let lastPart = parts.pop() ?? ''; // remove the last part from the array and keep it in buffer + for (let part of parts) { + if (part.length > 0) { + if (transformFunction) { + yield transformFunction(part, state); + } else { + yield part + splitPattern; + } + } + } + + buffer = lastPart; // keep the last part (after the last '\n\n') in buffer + } + } +} + +// currently this function expects the response stream to have either error or response.completed event +// we store the response.completed event which contains the final response +export function parseOpenAIResponsesStreamResponse( + res: string, + splitPattern: string +) { + let finalResponseChunk: any; + for (let chunk of res.split(splitPattern)) { + chunk = chunk + .replace(/^event:.*\n?/gm, '') + .trim() + .replace(/^data: /, '') + .trim(); + + const obj = JSON.parse(chunk); + if ( + obj.type === 'error' || + obj.type === 'response.completed' || + obj.type === 'response.failed' || + obj.type === 'response.incomplete' + ) { + finalResponseChunk = obj; + break; + } + } + if (!finalResponseChunk) { + throw new Error('Invalid response'); + } + if (finalResponseChunk.type === 'error') { + return { + id: 'portkey_cache' + crypto.randomUUID(), + object: 'response', + created_at: Math.floor(Date.now() / 1000), + status: 
'in_progress', + error: { + code: finalResponseChunk.code ?? 'server_error', + message: finalResponseChunk.message, + }, + incomplete_details: null, + instructions: null, + max_output_tokens: null, + model: '', + output: [], + parallel_tool_calls: true, + previous_response_id: null, + reasoning: { + effort: null, + generate_summary: null, + }, + store: true, + temperature: 1.0, + tool_choice: '', + tools: [], + top_p: 1.0, + truncation: '', + usage: null, + user: null, + metadata: {}, + }; + } + + return finalResponseChunk.response; +} + +export const parseOpenAIImageEditStreamResponse = ( + res: string, + splitPattern: string +) => { + for (let chunk of res.split(splitPattern)) { + chunk = chunk + .replace(/^event:.*\n?/gm, '') + .trim() + .replace(/^data: /, '') + .trim(); + const obj = JSON.parse(chunk); + if (obj.type === 'image_edit.completed') { + const response = { ...obj }; + delete response.type; + response.data = [ + { + b64_json: response.b64_json, + }, + ]; + response.created = response.created_at; + delete response.created_at; + delete response.b64_json; + return response; + } + } +}; diff --git a/src/middlewares/portkey/handlers/usage.ts b/src/middlewares/portkey/handlers/usage.ts new file mode 100644 index 000000000..77df9243c --- /dev/null +++ b/src/middlewares/portkey/handlers/usage.ts @@ -0,0 +1,143 @@ +import { AtomicCounterTypes, AtomicOperations, EntityStatus } from '../globals'; +import { + AtomicCounterRequestType, + IntegrationDetails, + OrganisationDetails, + UsageLimits, + VirtualKeyDetails, + WorkspaceDetails, +} from '../types'; + +function generateAtomicKey({ + organisationId, + type, + key, +}: Partial) { + return `${organisationId}-${type}-${key}`; +} + +async function atomicCounterHandler({ + env, + organisationId, + type, + key, + amount, + operation, + counterType, + metadata, + usageLimitId, +}: { env: Record } & Partial) { + counterType = counterType || AtomicCounterTypes.COST; + const atomicCounterStub = env.ATOMIC_COUNTER.get( 
+ env.ATOMIC_COUNTER.idFromName( + generateAtomicKey({ organisationId, type, key }) + ) + ); + const body: Partial = { + organisationId, + type, + key, + operation, + amount, + counterType, + metadata, + usageLimitId, + }; + const apiRes = await atomicCounterStub.fetch( + 'https://api.portkey.ai/v1/health', + { + method: 'POST', + body: JSON.stringify(body), + } + ); + const data = await apiRes.json(); + return data; +} + +export async function getCurrentUsage({ + env, + organisationId, + type, + key, + counterType, + usageLimitId, +}: { env: Record } & Partial) { + return atomicCounterHandler({ + env, + organisationId, + type, + key, + counterType, + operation: AtomicOperations.GET, + usageLimitId, + }); +} + +export async function incrementUsage({ + env, + organisationId, + type, + key, + amount, + counterType, + usageLimitId, +}: { env: Record } & Partial) { + return atomicCounterHandler({ + env, + organisationId, + type, + key, + amount, + counterType, + operation: AtomicOperations.INCREMENT, + usageLimitId, + }); +} + +export async function resetUsage({ + env, + organisationId, + type, + key, + counterType, + usageLimitId, +}: { env: Record } & AtomicCounterRequestType) { + return atomicCounterHandler({ + env, + organisationId, + type, + key, + counterType, + operation: AtomicOperations.RESET, + usageLimitId, + }); +} + +export function preRequestUsageValidator({ + env, + entity, + usageLimits, + metadata, +}: { + env: Record; + entity: + | OrganisationDetails['apiKeyDetails'] + | WorkspaceDetails + | VirtualKeyDetails + | IntegrationDetails; + usageLimits: UsageLimits[]; + metadata?: Record; +}) { + let isExhausted = entity?.status === EntityStatus.EXHAUSTED; + for (const usageLimit of usageLimits) { + if (isExhausted) { + break; + } + isExhausted = usageLimit.status === EntityStatus.EXHAUSTED; + } + const isExpired = entity?.status === EntityStatus.EXPIRED; + return { + isExhausted, + isExpired, + }; +} diff --git a/src/middlewares/portkey/index.ts 
b/src/middlewares/portkey/index.ts new file mode 100644 index 000000000..5cabd389c --- /dev/null +++ b/src/middlewares/portkey/index.ts @@ -0,0 +1,696 @@ +import { Context } from 'hono'; +import { + getDebugLogSetting, + getMappedCacheType, + getPortkeyHeaders, + postResponseHandler, + preRequestValidator, + updateHeaders, + getStreamingMode, + addBackgroundTask, +} from './utils'; +import { + HEADER_KEYS, + RESPONSE_HEADER_KEYS, + CONTENT_TYPES, + MODES, + CACHE_STATUS, + cacheDisabledRoutesRegex, +} from './globals'; +import { BaseGuardrail, OrganisationDetails, WinkyLogObject } from './types'; +import { + getStreamModeSplitPattern, + parseResponse, + readStream, +} from './handlers/stream'; +import { getFromCache } from './handlers/cache'; +import { env } from 'hono/adapter'; +import { + createRequestFromPromptData, + getMappedConfigFromRequest, + handleVirtualKeyHeader, + getUniquePromptPartialsFromPromptMap, + getPromptPartialMap, + getGuardrailMappedConfig, +} from './handlers/helpers'; +import { fetchOrganisationPrompt } from './handlers/albus'; +import { hookHandler } from './handlers/hooks'; +import { fetchFromKVStore, putInKVStore } from './handlers/kv'; +import { realtimeEventLogHandler } from './handlers/realtime'; +import { + handleCircuitBreakerResponse, + recordCircuitBreakerFailure, +} from './circuitBreaker'; +import { fetchOrganisationDetailsFromFile } from './handlers/configFile'; + +function getContentType(headersObj: any) { + if ('content-type' in headersObj) { + return headersObj['content-type'].split(';')[0]; + } else { + return null; + } +} + +async function getRequestBodyData(req: Request, headersObj: any) { + const contentType = getContentType(headersObj); + let bodyJSON: any = {}; + let bodyFormData = new FormData(); + let requestBinary: ArrayBuffer = new ArrayBuffer(0); + + switch (contentType) { + case CONTENT_TYPES.APPLICATION_JSON: { + if (req.method === 'GET' || req.method === 'DELETE') { + bodyJSON = {}; + break; + } + bodyJSON = 
await req.json(); + break; + } + case CONTENT_TYPES.MULTIPART_FORM_DATA: { + bodyFormData = await req.formData(); + bodyFormData.forEach(function (value, key) { + bodyJSON[key] = value; + }); + break; + } + } + if (contentType?.startsWith(CONTENT_TYPES.GENERIC_AUDIO_PATTERN)) { + requestBinary = await req.arrayBuffer(); + } + return { bodyJSON, bodyFormData, requestBinary }; +} + +export function getMode(requestHeaders: Record, path: string) { + let mode = requestHeaders[HEADER_KEYS.MODE]?.split(' ')[0] ?? MODES.PROXY; + if ( + path === '/v1/chatComplete' || + path === '/v1/complete' || + path === '/v1/embed' + ) { + mode = MODES.RUBEUS; + } else if ( + path === '/v1/chat/completions' || + path === '/v1/messages' || + path === '/v1/completions' || + path === '/v1/embeddings' || + path === '/v1/images/generations' || + path === '/v1/images/edits' || + path === '/v1/audio/speech' || + path === '/v1/audio/transcriptions' || + path === '/v1/audio/translations' || + path.includes('/v1/batches') || + path.includes('/v1/fine_tuning') || + path.includes('/v1/files') || + path.startsWith('/v1/prompts') || + path.startsWith('/v1/responses') + ) { + mode = MODES.RUBEUS_V2; + } else if (path.startsWith('/v1/realtime')) { + mode = MODES.REALTIME; + } else if (path.indexOf('/v1/proxy') === -1) { + mode = MODES.PROXY_V2; + } + + return mode; +} + +export const portkey = () => { + return async (c: Context, next: any) => { + const reqClone = c.req.raw.clone(); + let headersObj = Object.fromEntries(c.req.raw.headers); + if (!headersObj[HEADER_KEYS.ORGANISATION_DETAILS]) { + headersObj[HEADER_KEYS.ORGANISATION_DETAILS] = JSON.stringify( + await fetchOrganisationDetailsFromFile() + ); + } + + const url = new URL(c.req.url); + const path = url.pathname; + const requestBodyData = await getRequestBodyData(reqClone, headersObj); + + const store: { + orgDetails: OrganisationDetails; + [key: string]: any; + } = { + bodyJSON: requestBodyData.bodyJSON, + bodyFormData: 
requestBodyData.bodyFormData, + bodyBinary: requestBodyData.requestBinary, + proxyMode: getMode(headersObj, path), + orgAPIKey: headersObj[HEADER_KEYS.API_KEY], + orgDetails: JSON.parse( + headersObj[HEADER_KEYS.ORGANISATION_DETAILS] + ) as OrganisationDetails, + requestMethod: c.req.method, + requestContentType: getContentType(headersObj), + }; + + try { + updateHeaders(headersObj, store.orgDetails); + } catch (err: any) { + return new Response( + JSON.stringify({ + status: 'failure', + message: err.message, + }), + { + status: 400, + headers: { + 'Content-Type': 'application/json', + }, + } + ); + } + + const mappedHeaders = new Headers(headersObj); + + let mappedBody = store.bodyJSON; + let mappedURL = c.req.url; + const isVirtualKeyUsageEnabled = + store.orgDetails.settings.is_virtual_key_limit_enabled; + + const { + input_guardrails: defaultOrganisationInputGuardrails, + output_guardrails: defaultOrganisationOutputGuardrails, + } = store.orgDetails?.defaults || {}; + + const { + input_guardrails: defaultWorkspaceInputGuardrails, + output_guardrails: defaultWorkspaceOutputGuardrails, + } = store.orgDetails?.workspaceDetails?.defaults || {}; + + const defaultInputGuardrails: BaseGuardrail[] = [ + ...(defaultOrganisationInputGuardrails || []).map( + (eachGuardrail: any) => ({ + slug: eachGuardrail.slug, + organisationId: store.orgDetails.id, + workspaceId: null, + }) + ), + ...(defaultWorkspaceInputGuardrails || []).map((eachGuardrail: any) => ({ + slug: eachGuardrail.slug, + organisationId: store.orgDetails.id, + workspaceId: store.orgDetails.workspaceDetails?.id, + })), + ]; + const defaultOutputGuardrails: BaseGuardrail[] = [ + ...(defaultOrganisationOutputGuardrails || []).map( + (eachGuardrail: any) => ({ + slug: eachGuardrail.slug, + organisationId: store.orgDetails.id, + workspaceId: null, + }) + ), + ...(defaultWorkspaceOutputGuardrails || []).map((eachGuardrail: any) => ({ + slug: eachGuardrail.slug, + organisationId: store.orgDetails.id, + workspaceId: 
store.orgDetails.workspaceDetails?.id, + })), + ]; + + const defaultGuardrails: BaseGuardrail[] = [ + ...defaultInputGuardrails, + ...defaultOutputGuardrails, + ]; + + // start: config mapping + const { + status: configStatus, + message: configStatusMessage, + mappedConfig, + configVersion, + configSlug, + promptRequestURL, + guardrailMap = {}, + integrations = [], + circuitBreakerContext, + } = await getMappedConfigFromRequest( + env(c), + mappedBody, + mappedHeaders, + store.orgAPIKey, + store.orgDetails, + path, + isVirtualKeyUsageEnabled, // TODO: Pick this up from orgDetails after guardrails deployment + mappedHeaders, + defaultGuardrails + ); + + if (configStatus === 'failure') { + return new Response( + JSON.stringify({ + status: 'failure', + message: configStatusMessage, + }), + { + status: 400, + headers: { + 'Content-Type': 'application/json', + }, + } + ); + } + + if (circuitBreakerContext) { + c.set('handleCircuitBreakerResponse', handleCircuitBreakerResponse); + c.set('recordCircuitBreakerFailure', recordCircuitBreakerFailure); + } + + if (defaultGuardrails.length > 0) { + const mappedDefaultGuardrails = getGuardrailMappedConfig( + guardrailMap, + { + input_guardrails: defaultInputGuardrails?.map( + (eachGuardrail: BaseGuardrail) => eachGuardrail.slug + ), + output_guardrails: defaultOutputGuardrails?.map( + (eachGuardrail: BaseGuardrail) => eachGuardrail.slug + ), + }, + integrations, + store.orgAPIKey, + store.orgDetails + ); + + if (mappedDefaultGuardrails.input_guardrails) { + mappedHeaders.set( + HEADER_KEYS.DEFAULT_INPUT_GUARDRAILS, + JSON.stringify(mappedDefaultGuardrails.input_guardrails) + ); + } + if (mappedDefaultGuardrails.output_guardrails) { + mappedHeaders.set( + HEADER_KEYS.DEFAULT_OUTPUT_GUARDRAILS, + JSON.stringify(mappedDefaultGuardrails.output_guardrails) + ); + } + } + // add config slug and version in header for winky logging. 
+ if (configSlug && configVersion) { + mappedHeaders.set(HEADER_KEYS.CONFIG_SLUG, configSlug); + mappedHeaders.set(HEADER_KEYS.CONFIG_VERSION, configVersion); + } + + if (mappedConfig && store.proxyMode === MODES.RUBEUS) { + mappedBody.config = mappedConfig; + } else if (mappedConfig) { + mappedHeaders.set(HEADER_KEYS.CONFIG, JSON.stringify(mappedConfig)); + } + // end: config mapping + + // start: fetch and map prompt data + if (c.req.url.includes('/v1/prompts') && !promptRequestURL) { + const promptSlug = new URL(c.req.url).pathname?.split('/')[3]; + + const promptData = await fetchOrganisationPrompt( + env(c), + store.orgDetails.id, + store.orgDetails.workspaceDetails, + store.orgAPIKey, + promptSlug, + headersObj[HEADER_KEYS.REFRESH_PROMPT_CACHE] === 'true' + ); + if (!promptData) { + return new Response( + JSON.stringify({ + status: 'failure', + message: 'Invalid prompt id', + }), + { + status: 404, + headers: { + 'Content-Type': 'application/json', + }, + } + ); + } + + const { missingVariablePartials, uniquePartials } = + getUniquePromptPartialsFromPromptMap( + { [promptSlug]: promptData }, + mappedBody + ); + if (missingVariablePartials.length) { + return new Response( + JSON.stringify({ + status: 'failure', + message: `Missing variable partials: ${missingVariablePartials.join( + ', ' + )}`, + }), + { + status: 400, + headers: { + 'Content-Type': 'application/json', + }, + } + ); + } + + const { promptPartialMap, missingPromptPartials } = + await getPromptPartialMap( + env(c), + uniquePartials, + store.orgAPIKey, + store.orgDetails.id, + store.orgDetails.workspaceDetails + ); + if (missingPromptPartials.length) { + return new Response( + JSON.stringify({ + status: 'failure', + message: `Missing prompt partials: ${missingPromptPartials.join( + ', ' + )}`, + }), + { + status: 400, + headers: { + 'Content-Type': 'application/json', + }, + } + ); + } + const { + requestBody: promptRequestBody, + requestHeader, + requestUrl, + status: promptStatus, + message: 
promptStatusMessage, + } = createRequestFromPromptData( + env(c), + promptData, + promptPartialMap, + store.bodyJSON, + promptSlug + ); + + if (promptStatus == 'failure') { + return new Response( + JSON.stringify({ + status: 'failure', + message: promptStatusMessage, + }), + { + status: 400, + headers: { + 'Content-Type': 'application/json', + }, + } + ); + } + mappedBody = promptRequestBody; + + mappedURL = requestUrl; + Object.entries(requestHeader ?? {}).forEach(([key, value]) => { + mappedHeaders.set(key, value); + }); + mappedHeaders.delete('content-length'); + } else if (c.req.url.includes('/v1/prompts')) { + delete mappedBody['variables']; + if (!promptRequestURL) { + return new Response( + JSON.stringify({ + status: 'failure', + message: 'prompt completions error: Something went wrong', + }), + { + status: 500, + headers: { + 'Content-Type': 'application/json', + }, + } + ); + } + mappedURL = promptRequestURL; + } + // end: fetch and map prompt data + + // start: check and map virtual key header + const handleVirtualKeyHeaderResponse = await handleVirtualKeyHeader( + env(c), + store.orgAPIKey, + store.orgDetails.id, + store.orgDetails.workspaceDetails, + mappedHeaders, + store.proxyMode, + mappedBody, + mappedURL + ); + + if (handleVirtualKeyHeaderResponse?.status === 'failure') { + return new Response( + JSON.stringify({ + status: 'failure', + message: handleVirtualKeyHeaderResponse.message, + }), + { + status: 400, + headers: { + 'Content-Type': 'application/json', + }, + } + ); + } + // end: check and map virtual key header + + const modifiedFetchOptions: RequestInit = { + headers: mappedHeaders, + method: store.requestMethod, + }; + + if (store.requestContentType === CONTENT_TYPES.MULTIPART_FORM_DATA) { + modifiedFetchOptions.body = store.bodyFormData; + mappedHeaders.delete('content-type'); + } else if ( + store.requestContentType?.startsWith(CONTENT_TYPES.GENERIC_AUDIO_PATTERN) + ) { + modifiedFetchOptions.body = store.requestBinary; + } else if ( + 
store.requestMethod !== 'GET' && + store.requestMethod !== 'DELETE' && + store.requestContentType + ) { + modifiedFetchOptions.body = JSON.stringify(mappedBody); + } + + let modifiedRequest: Request; + + // TODO: Verify if we can just ```new Request(mappedURL, modifiedFetchOptions)``` for both the conditions + if (path.startsWith('/v1/prompts/')) { + modifiedRequest = new Request(mappedURL, modifiedFetchOptions); + } else { + modifiedRequest = new Request(c.req.raw, modifiedFetchOptions); + } + + const mappedHeadersObj = Object.fromEntries(mappedHeaders); + + c.req.raw = modifiedRequest; + + let executionStartTimeStamp = Date.now(); + + if (!cacheDisabledRoutesRegex.test(path)) { + c.set('getFromCache', getFromCache); + c.set('cacheIdentifier', store.orgDetails.id); + } + + if (c.req.url.includes('/v1/realtime')) { + c.set('realtimeEventParser', realtimeEventLogHandler); + } + c.set('getFromCacheByKey', fetchFromKVStore); + c.set('putInCacheWithValue', putInKVStore); + c.set('preRequestValidator', preRequestValidator); + + const headersWithoutPortkeyHeaders = Object.assign( + {}, + Object.fromEntries(mappedHeaders) + ); //deep copy + Object.values(HEADER_KEYS).forEach((eachPortkeyHeader) => { + delete headersWithoutPortkeyHeaders[eachPortkeyHeader]; + }); + + const portkeyHeaders = getPortkeyHeaders(mappedHeadersObj); + + // Main call handler is here + await next(); + + const requestOptionsArray = c.get('requestOptions'); + if (!requestOptionsArray?.length) { + return; + } + const internalTraceId = crypto.randomUUID(); + + for (const latestRequestOption of requestOptionsArray) { + const logId = crypto.randomUUID(); + const provider = latestRequestOption.providerOptions.provider; + const params = latestRequestOption.requestParams; + const isStreamingMode = getStreamingMode( + params, + provider, + latestRequestOption.providerOptions.requestURL, + latestRequestOption.providerOptions.rubeusURL + ); + const currentTimestamp = Date.now(); + const winkyBaseLog: 
WinkyLogObject = { + id: logId, + createdAt: latestRequestOption.createdAt, + traceId: portkeyHeaders['x-portkey-trace-id'], + internalTraceId: internalTraceId, + requestMethod: store.requestMethod, + requestURL: latestRequestOption.providerOptions.requestURL, + rubeusURL: latestRequestOption.providerOptions.rubeusURL, + finalUntransformedRequest: + latestRequestOption.finalUntransformedRequest ?? null, + transformedRequest: latestRequestOption.transformedRequest ?? null, + requestHeaders: headersWithoutPortkeyHeaders, + requestBody: JSON.stringify(mappedBody), + originalResponse: latestRequestOption.originalResponse ?? null, + requestBodyParams: params, + responseStatus: latestRequestOption.response.status, + responseTime: currentTimestamp - executionStartTimeStamp, + responseHeaders: Object.fromEntries( + latestRequestOption.response.headers + ), + cacheKey: latestRequestOption.cacheKey, + providerOptions: latestRequestOption.providerOptions, + debugLogSetting: getDebugLogSetting(portkeyHeaders, store.orgDetails), + config: { + organisationDetails: store.orgDetails, + organisationConfig: {}, + cacheType: getMappedCacheType(latestRequestOption.cacheMode), + retryCount: + Number( + c.res.headers.get(RESPONSE_HEADER_KEYS.RETRY_ATTEMPT_COUNT) + ) || 0, + portkeyHeaders: portkeyHeaders, + proxyMode: store.proxyMode, + streamingMode: isStreamingMode ?? false, + cacheStatus: latestRequestOption.cacheStatus ?? 
CACHE_STATUS.DISABLED, + provider: provider, + requestParams: params, + lastUsedOptionIndex: latestRequestOption.lastUsedOptionIndex, + internalTraceId: internalTraceId, + cacheMaxAge: latestRequestOption.cacheMaxAge || null, + }, + } as WinkyLogObject; + const responseClone = latestRequestOption.response; + const isCacheHit = [CACHE_STATUS.HIT, CACHE_STATUS.SEMANTIC_HIT].includes( + winkyBaseLog.config.cacheStatus + ); + const isStreamEnabledCacheHit = + isCacheHit && store.proxyMode === MODES.RUBEUS_V2; + const responseContentType = responseClone.headers.get('content-type'); + + let concatenatedStreamResponse = ''; + + if ( + isStreamingMode && + [200, 246].includes(responseClone.status) && + (!isCacheHit || isStreamEnabledCacheHit) + ) { + let splitPattern = '\n\n'; + if ([MODES.PROXY && MODES.PROXY_V2].includes(store.proxyMode)) { + splitPattern = getStreamModeSplitPattern( + provider, + winkyBaseLog.requestURL + ); + } + + (async () => { + for await (const chunk of readStream( + responseClone.body!.getReader(), + splitPattern, + undefined + )) { + concatenatedStreamResponse += chunk; + } + + const responseBodyJson = parseResponse( + concatenatedStreamResponse, + provider, + store.proxyMode, + winkyBaseLog.requestURL, + winkyBaseLog.rubeusURL + ); + winkyBaseLog.responseBody = responseBodyJson + ? 
JSON.stringify(responseBodyJson) + : '{ "info": "Portkey logging: Unable to log streaming response" }'; + addBackgroundTask( + c, + postResponseHandler(winkyBaseLog, responseBodyJson, env(c)) + ); + + const hooksManager = c.get('hooksManager'); + if (hooksManager && responseBodyJson) { + hooksManager.setSpanContextResponse( + latestRequestOption.hookSpanId, + responseBodyJson, + latestRequestOption.response.status + ); + } + addBackgroundTask( + c, + hookHandler( + c, + latestRequestOption.hookSpanId, + { + 'x-auth-organisation-details': + headersObj[HEADER_KEYS.ORGANISATION_DETAILS], + }, + winkyBaseLog + ) + ); + })(); + } else if ( + responseContentType?.startsWith(CONTENT_TYPES.GENERIC_AUDIO_PATTERN) || + responseContentType?.startsWith( + CONTENT_TYPES.APPLICATION_OCTET_STREAM + ) || + responseContentType?.startsWith(CONTENT_TYPES.GENERIC_IMAGE_PATTERN) + ) { + const responseBodyJson = {}; + winkyBaseLog.responseBody = JSON.stringify(responseBodyJson); + addBackgroundTask( + c, + postResponseHandler(winkyBaseLog, responseBodyJson, env(c)) + ); + } else if ( + responseContentType?.startsWith(CONTENT_TYPES.PLAIN_TEXT) || + responseContentType?.startsWith(CONTENT_TYPES.HTML) + ) { + const responseBodyJson = { + 'html-message': await responseClone.text(), + }; + winkyBaseLog.responseBody = JSON.stringify(responseBodyJson); + addBackgroundTask( + c, + postResponseHandler(winkyBaseLog, responseBodyJson, env(c)) + ); + } else if (!responseContentType && responseClone.status === 204) { + const responseBodyJson = {}; + winkyBaseLog.responseBody = JSON.stringify(responseBodyJson); + addBackgroundTask( + c, + postResponseHandler(winkyBaseLog, responseBodyJson, env(c)) + ); + } else { + const promise = new Promise(async (resolve, reject) => { + const responseBodyJson = await responseClone.json(); + winkyBaseLog.responseBody = JSON.stringify(responseBodyJson); + + await postResponseHandler(winkyBaseLog, responseBodyJson, env(c)); + await hookHandler( + c, + 
latestRequestOption.hookSpanId, + { + 'x-auth-organisation-details': + headersObj[HEADER_KEYS.ORGANISATION_DETAILS], + }, + winkyBaseLog + ); + resolve(true); + }); + addBackgroundTask(c, promise); + } + } + }; +}; diff --git a/src/middlewares/portkey/mustache.d.ts b/src/middlewares/portkey/mustache.d.ts new file mode 100644 index 000000000..24c859853 --- /dev/null +++ b/src/middlewares/portkey/mustache.d.ts @@ -0,0 +1,4 @@ +declare module '@portkey-ai/mustache' { + export * from '@types/mustache'; + export function getTemplateDetails(): any; +} diff --git a/src/middlewares/portkey/types.ts b/src/middlewares/portkey/types.ts new file mode 100644 index 000000000..d46d8432b --- /dev/null +++ b/src/middlewares/portkey/types.ts @@ -0,0 +1,509 @@ +import { + AtomicOperations, + AtomicKeyTypes, + AtomicCounterTypes, + EntityStatus, + RateLimiterTypes, +} from './globals'; + +export interface WinkyLogObject { + id: string; + traceId: string; + internalTraceId: string; + requestMethod: string; + requestURL: string; + rubeusURL: string; + requestHeaders: Record; + requestBody: string; + requestBodyParams: Record; + finalUntransformedRequest?: Record; + transformedRequest?: Record; + originalResponse?: Record; + createdAt: Date; + responseHeaders: Record | null; + responseBody: string | null; + responseStatus: number; + responseTime: number; + cacheKey: string; + providerOptions: Record; + debugLogSetting: boolean; + config: { + organisationConfig: Record | null; + organisationDetails: OrganisationDetails; + cacheStatus: string; + cacheType: string | null; + retryCount: number; + portkeyHeaders: Record | null; + proxyMode: string; + streamingMode: boolean; + provider: string; + requestParams: Record; + lastUsedOptionIndex: number; + internalTraceId: string; + cacheMaxAge: number | null; + }; +} + +export interface LogObjectRequest { + url: string; + method: string; + headers: Record; + body: any; + status: number; + provider?: string; +} + +export interface LogObjectResponse 
{ + status: number; + headers: Record; + body: any; + response_time: number; + streamingMode: boolean; +} + +export interface LogObjectMetadata extends Record { + traceId?: string; + spanId?: string; + parentSpanId?: string; + spanName?: string; +} + +export interface LogObject { + request: LogObjectRequest; + response: LogObjectResponse; + metadata: LogObjectMetadata; + createdAt: string; + organisationDetails: OrganisationDetails; // Needed for auth in logging +} + +export interface HookFeedbackMetadata extends Record { + successfulChecks: string; + failedChecks: string; + erroredChecks: string; +} + +export interface HookFeedback { + value: number; + weight: number; + metadata: HookFeedbackMetadata; +} + +export interface CheckResult { + id: string; + verdict: boolean; + error?: { + name: string; + message: string; + } | null; + data: null | Record; + log?: LogObject; +} + +export interface HookResult { + verdict: boolean; + id: string; + checks: CheckResult[]; + feedback: HookFeedback; + deny: boolean; + async: boolean; +} + +export interface HookResultWithLogDetails extends HookResult { + event_type: 'beforeRequestHook' | 'afterRequestHook'; + guardrail_version_id: string; +} + +export interface HookResultLogObject { + generation_id: string; + trace_id: string; + internal_trace_id: string; + organisation_id: string; + workspace_slug: string; + results: HookResultWithLogDetails[]; + organisation_details: OrganisationDetails; +} + +interface OpenAIChoiceMessage { + role: string; + content: string; + content_blocks?: OpenAIChoiceMessageContentType[]; + tool_calls?: any; +} + +/** + * A message content type. 
+ * @interface + */ +export interface OpenAIChoiceMessageContentType { + type: string; + text?: string; + thinking?: string; + signature?: string; + image_url?: { + url: string; + detail?: string; + }; + data?: string; +} +interface OpenAIChoice { + index: string; + finish_reason: string; + message?: OpenAIChoiceMessage; + text?: string; + logprobs?: any; + groundingMetadata?: GroundingMetadata; +} + +interface AnthropicPromptUsageTokens { + cache_read_input_tokens?: number; + cache_creation_input_tokens?: number; +} + +interface OpenAIUsage extends AnthropicPromptUsageTokens { + completion_tokens: number; + prompt_tokens?: number; + total_tokens?: number; + num_search_queries?: number; + completion_tokens_details?: { + accepted_prediction_tokens?: number; + audio_tokens?: number; + reasoning_tokens?: number; + rejected_prediction_tokens?: number; + }; + prompt_tokens_details?: { + audio_tokens?: number; + cached_tokens?: number; + }; +} + +export interface OpenAIStreamResponse { + id: string; + object: string; + created: string; + choices: OpenAIChoice[]; + model: string; + usage: OpenAIUsage; + citations?: Record; +} + +interface CohereGeneration { + id: string; + text: string; + finish_reason: string; +} + +export interface CohereStreamResponse { + id: string; + generations: CohereGeneration[]; + prompt: string; +} + +export interface ParsedChunk { + is_finished: boolean; + finish_reason: string; + response?: { + id: string; + generations: CohereGeneration[]; + prompt: string; + }; + text?: string; +} + +export interface AnthropicCompleteStreamResponse { + completion: string; + stop_reason: string; + model: string; + truncated?: boolean; + stop: null | string; + log_id: string; + exception?: any | null; +} + +export interface AnthropicMessagesStreamResponse { + id: string; + type: string; + role: string; + content: { + type: string; + text: string; + }[]; + model: string; + stop_reason: string; + stop_sequence: string | null; +} + +interface 
GoogleGenerateFunctionCall { + name: string; + args: Record; +} + +export interface GoogleGenerateContentResponse { + candidates: { + content: { + parts: { + text?: string; + functionCall?: GoogleGenerateFunctionCall; + }[]; + role: string; + }; + finishReason: string; + index: 0; + safetyRatings: { + category: string; + probability: string; + }[]; + }[]; + promptFeedback: { + safetyRatings: { + category: string; + probability: string; + }[]; + }; +} + +export interface TogetherAIResponse { + id: string; + choices: { + text?: string; + message?: { + role: string; + content: string; + }; + }[]; + created: string; + model: string; + object: string; +} + +export interface TogetherInferenceResponse { + status: string; + output: { + choices: { + text: string; + }[]; + request_id: string; + }; +} + +export interface OllamaCompleteResponse { + model: string; + created_at: string; + response: string; + done: boolean; + context: number[]; + total_duration: number; + load_duration: number; + prompt_eval_count: number; + prompt_eval_duration: number; + eval_count: number; + eval_duration: number; +} + +export interface OllamaCompleteStreamReponse { + model: string; + created_at: number; + response: string; + done: boolean; + context: number[]; +} + +export interface OllamaChatCompleteResponse { + model: string; + created_at: number; + message: { + role: string; + content: string; + }; + done: boolean; + total_duration: number; + load_duration: number; + prompt_eval_count: number; + prompt_eval_duration: number; + eval_count: number; + eval_duration: number; +} + +export interface OllamaChatCompleteStreamResponse { + model: string; + created_at: string; + message: { + role: string; + content: string; + }; + done: boolean; + total_duration: number; + load_duration: number; + prompt_eval_count?: number; + prompt_eval_duration: number; + eval_count: number; + eval_duration: number; +} + +export type AtomicCounterRequestType = { + operation: AtomicOperations; + type: 
AtomicKeyTypes; + organisationId: string; + key: string; + amount: number; + metadata?: Record; + usageLimitId?: string; + counterType?: AtomicCounterTypes; +}; + +export interface RateLimit { + type: RateLimiterTypes; + unit: string; + value: number; +} + +interface ApiKeyDefaults { + config_slug?: string; + config_id?: string; + metadata?: Record; + input_guardrails?: OrganisationDefaults['input_guardrails']; + output_guardrails?: OrganisationDefaults['output_guardrails']; + allow_config_override?: boolean; +} + +interface OrganisationDefaults { + input_guardrails?: string[]; + output_guardrails?: string[]; +} + +export interface UsageLimits { + id?: string; + type: AtomicCounterTypes; + credit_limit: number | null; + alert_threshold: number | null; + is_threshold_alerts_sent: boolean | null; + is_exhausted_alerts_sent: boolean | null; + periodic_reset: string | null; + current_usage: number | null; + last_reset_at: string | null; + metadata: Record; + status: EntityStatus; +} + +interface ApiKeyDetails { + id: string; + expiresAt?: string; + scopes: string[]; + rateLimits?: RateLimit[]; + defaults: ApiKeyDefaults; + usageLimits: UsageLimits[]; + systemDefaults?: { + user_name: string; + user_key_metadata_override: boolean; + }; +} + +export interface WorkspaceDetails { + id: string; + slug: string; + defaults: ApiKeyDefaults; + usage_limits: UsageLimits[]; + rate_limits: RateLimit[]; + status: EntityStatus; +} + +export interface OrganisationDetails { + id: string; + ownerId?: string; + name?: string; + settings: Record; + isFirstGenerationDone: boolean; + enterpriseSettings?: Record; + workspaceDetails: WorkspaceDetails; + scopes: string[]; + rateLimits?: RateLimit[]; + defaults: ApiKeyDefaults; + usageLimits: UsageLimits[]; + status: EntityStatus; + apiKeyDetails: { + id: string; + key: string; + scopes: string[]; + rateLimits?: RateLimit[]; + defaults: ApiKeyDefaults; + usageLimits: UsageLimits[]; + status: EntityStatus; + expiresAt?: string; + 
systemDefaults?: { + user_name?: string; + user_key_metadata_override?: boolean; + }; + }; + organisationDefaults: OrganisationDefaults; +} + +export interface VirtualKeyDetails { + id: string; + slug: string; + usage_limits: UsageLimits[]; + rate_limits: RateLimit[]; + status: EntityStatus; + workspace_id: string; + organisation_id: string; + expires_at: string; +} + +export interface IntegrationDetails { + id: string; + slug: string; + usage_limits: UsageLimits[]; + rate_limits: RateLimit[]; + status: EntityStatus; + allow_all_models: boolean; + models: { + slug: string; + status: EntityStatus.ACTIVE | EntityStatus.ARCHIVED; + pricing_config: Record; + }[]; +} + +export interface BaseGuardrail { + slug: string; + organisationId: string; + workspaceId: string | null; +} + +export interface GroundingMetadata { + webSearchQueries?: string[]; + searchEntryPoint?: { + renderedContent: string; + }; + groundingSupports?: Array<{ + segment: { + startIndex: number; + endIndex: number; + text: string; + }; + groundingChunkIndices: number[]; + confidenceScores: number[]; + }>; + retrievalMetadata?: { + webDynamicRetrievalScore: number; + }; +} + +export interface IntegrationDetails { + id: string; + slug: string; + usage_limits: UsageLimits[]; + rate_limits: RateLimit[]; + status: EntityStatus; + allow_all_models: boolean; + models: { + slug: string; + status: EntityStatus.ACTIVE | EntityStatus.ARCHIVED; + pricing_config: Record; + }[]; +} + +export type AtomicCounterResponseType = { + value: number; + success: boolean; + message?: string; + type?: AtomicKeyTypes; + key?: string; +}; diff --git a/src/middlewares/portkey/utils.ts b/src/middlewares/portkey/utils.ts new file mode 100644 index 000000000..cd663e8b8 --- /dev/null +++ b/src/middlewares/portkey/utils.ts @@ -0,0 +1,717 @@ +import { Env, Context } from 'hono'; +import { env, getRuntimeKey } from 'hono/adapter'; +import { + AZURE_OPEN_AI, + BEDROCK, + CACHE_STATUS, + EntityStatus, + GOOGLE, + HEADER_KEYS, + MODES, + 
RateLimiterKeyTypes, + RateLimiterTypes, + VERTEX_AI, +} from './globals'; +import { putInCache } from './handlers/cache'; +import { forwardToWinky } from './handlers/logger'; +import { + IntegrationDetails, + OrganisationDetails, + VirtualKeyDetails, + WinkyLogObject, +} from './types'; +import { checkRateLimits, getRateLimit } from './handlers/helpers'; +import { preRequestUsageValidator } from './handlers/usage'; +import { + handleIntegrationRequestRateLimits, + preRequestRateLimitValidator, +} from './handlers/rateLimits'; +import { settings } from '../../../initializeSettings'; + +const runtime = getRuntimeKey(); + +export const getMappedCacheType = (cacheHeader: string) => { + if (!cacheHeader) { + return null; + } + + if (['simple', 'true'].includes(cacheHeader)) { + return 'simple'; + } + + if (cacheHeader === 'semantic') { + return 'semantic'; + } + + return null; +}; + +export function getPortkeyHeaders( + headersObj: Record +): Record { + let final: Record = {}; + const pkHeaderKeys = Object.values(HEADER_KEYS); + Object.keys(headersObj).forEach((key: string) => { + if (pkHeaderKeys.includes(key)) { + final[key] = headersObj[key]; + } + }); + delete final[HEADER_KEYS.ORGANISATION_DETAILS]; + return final; +} + +export async function postResponseHandler( + winkyBaseLog: WinkyLogObject, + responseBodyJson: Record, + env: Env +): Promise { + const cacheResponseBody = { ...responseBodyJson }; + // Put in Cache if needed + if ( + responseBodyJson && + winkyBaseLog.config.cacheType && + [ + CACHE_STATUS.MISS, + CACHE_STATUS.SEMANTIC_MISS, + CACHE_STATUS.REFRESH, + ].includes(winkyBaseLog.config.cacheStatus) && + winkyBaseLog.responseStatus === 200 && + winkyBaseLog.config.organisationDetails?.id && + winkyBaseLog.debugLogSetting + ) { + const cacheKeyUrl = [MODES.PROXY, MODES.PROXY_V2, MODES.API].includes( + winkyBaseLog.config.proxyMode + ) + ? 
winkyBaseLog.requestURL + : winkyBaseLog.rubeusURL; + + delete cacheResponseBody.hook_results; + await putInCache( + env, + { + ...winkyBaseLog.requestHeaders, + ...winkyBaseLog.config.portkeyHeaders, + }, + winkyBaseLog.requestBodyParams, + cacheResponseBody, + cacheKeyUrl, + winkyBaseLog.config.organisationDetails.id, + winkyBaseLog.config.cacheType, + winkyBaseLog.config.cacheMaxAge + ); + } + + // Log this request + if (env.WINKY_WORKER_BASEPATH) { + await forwardToWinky(env, winkyBaseLog); + } else if (settings) { + await handleTokenRateLimit(winkyBaseLog, responseBodyJson, env); + // TODO: make logs endpoint configurable + } + return; +} + +export const getStreamingMode = ( + reqBody: Record, + provider: string, + requestUrl: string, + rubeusUrl: string +): boolean => { + if ( + [GOOGLE, VERTEX_AI].includes(provider) && + requestUrl.indexOf('stream') > -1 + ) { + return true; + } + if ( + provider === BEDROCK && + (requestUrl.indexOf('invoke-with-response-stream') > -1 || + requestUrl.indexOf('converse-stream') > -1) + ) { + return true; + } + if (rubeusUrl === 'imageEdit') { + return reqBody.get('stream') === 'true'; + } + return reqBody?.stream; +}; + +/** + * Gets the debug log setting based on request headers and organisation details. + * Priority is given to x-portkey-debug header if its passed in request. + * Else default org level setting is considered. + * @param {Record} requestHeaders - The headers from the incoming request. + * @param {OrganisationDetails} organisationDetails - The details of the organisation. + * @returns {boolean} The debug log setting. 
+ */ +export function getDebugLogSetting( + requestHeaders: Record, + organisationDetails: OrganisationDetails +): boolean { + const debugSettingHeader = requestHeaders['x-portkey-debug']?.toLowerCase(); + + if (debugSettingHeader === 'false') return false; + else if (debugSettingHeader === 'true') return true; + + const organisationDebugLogSetting = organisationDetails.settings?.debug_log; + + if (organisationDebugLogSetting === 0) return false; + + return true; +} + +export async function preRequestValidator( + c: Context, + options: Record, + requestHeaders: Record, + params: Record, + metadata: Record +) { + const organisationDetails = requestHeaders[HEADER_KEYS.ORGANISATION_DETAILS] + ? JSON.parse(requestHeaders[HEADER_KEYS.ORGANISATION_DETAILS]) + : null; + const virtualKeyDetails = + options.virtualKeyDetails || + requestHeaders[HEADER_KEYS.VIRTUAL_KEY_DETAILS]; + let virtualKeyDetailsObj: VirtualKeyDetails | null = null; + if (virtualKeyDetails) { + virtualKeyDetailsObj = + typeof virtualKeyDetails === 'string' + ? JSON.parse(virtualKeyDetails) + : virtualKeyDetails; + } + const integrationDetails = + options.integrationDetails || + requestHeaders[HEADER_KEYS.INTEGRATION_DETAILS]; + let integrationDetailsObj: IntegrationDetails | null = null; + if (integrationDetails) { + integrationDetailsObj = + typeof integrationDetails === 'string' + ? 
JSON.parse(integrationDetails) + : integrationDetails; + } + const model = params?.model || null; + // Validate Statuses + let errorResponse = await validateEntityStatus( + c, + options, + organisationDetails, + virtualKeyDetailsObj, + integrationDetailsObj, + model, + metadata + ); + if (errorResponse) { + return errorResponse; + } + // Validate Rate Limits + const maxTokens = params.max_tokens || params.max_completion_tokens || 1; + errorResponse = await validateEntityTokenRateLimits( + c, + organisationDetails, + virtualKeyDetailsObj, + integrationDetailsObj, + maxTokens + ); + if (errorResponse) { + return errorResponse; + } +} + +async function validateEntityStatus( + c: Context, + options: Record, + organisationDetails: OrganisationDetails, + virtualKeyDetailsObj: VirtualKeyDetails | null, + integrationDetails: IntegrationDetails | null, + model: string | null, + metadata: Record +) { + let isExhausted = false; + let isExpired = false; + let errorMessage = ''; + let errorType = ''; + const [ + { isExhausted: isApiKeyExhausted, isExpired: isApiKeyExpired }, + { isExhausted: isWorkspaceExhausted, isExpired: isWorkspaceExpired }, + { isExhausted: isVirtualKeyExhausted, isExpired: isVirtualKeyExpired }, + { isExhausted: isIntegrationExhausted, isExpired: isIntegrationExpired }, + ] = await Promise.all([ + preRequestUsageValidator({ + env: env(c), + entity: organisationDetails.apiKeyDetails, + usageLimits: organisationDetails.apiKeyDetails?.usageLimits || [], + metadata, + }), + preRequestUsageValidator({ + env: env(c), + entity: organisationDetails.workspaceDetails, + usageLimits: organisationDetails.workspaceDetails?.usage_limits || [], + metadata, + }), + validateVirtualKeyStatus(c, virtualKeyDetailsObj, metadata), + validateIntegrationStatus(c, integrationDetails, metadata), + ]); + + if (isApiKeyExhausted) { + isExhausted = true; + errorMessage = `Portkey API Key Usage Limit ${hash( + organisationDetails.apiKeyDetails.key + )} Exceeded`; + errorType = 
'api_key_exhaust_error'; + } else if (isWorkspaceExhausted) { + isExhausted = true; + errorMessage = `Portkey Workspace Usage Limit ${hash( + organisationDetails.workspaceDetails.slug + )} Exceeded`; + errorMessage = 'Portkey Workspace Usage Limit Exceeded'; + errorType = 'workspace_exhaust_error'; + } else if (isVirtualKeyExhausted) { + isExhausted = true; + errorMessage = `Portkey Virtual Key Usage Limit ${hash( + virtualKeyDetailsObj?.slug + )} Exceeded`; + errorType = 'virtual_key_exhaust_error'; + } else if (isIntegrationExhausted) { + isExhausted = true; + errorMessage = `Portkey Integration Usage Limit ${hash( + integrationDetails?.slug + )} Exceeded`; + errorType = 'integration_exhaust_error'; + } + if (isApiKeyExpired) { + isExpired = true; + errorMessage = `Portkey API Key Usage Limit ${hash( + organisationDetails.apiKeyDetails.key + )} Expired`; + errorType = 'api_key_expired_error'; + } else if (isWorkspaceExpired) { + isExpired = true; + errorMessage = `Portkey Workspace Usage Limit ${hash( + organisationDetails.workspaceDetails.slug + )} Expired`; + errorType = 'workspace_expired_error'; + } else if (isVirtualKeyExpired) { + isExpired = true; + errorMessage = `Portkey Virtual Key Usage Limit ${hash( + virtualKeyDetailsObj?.slug + )} Expired`; + errorType = 'virtual_key_expired_error'; + } else if (isIntegrationExpired) { + isExpired = true; + errorMessage = `Portkey Integration Usage Limit ${hash( + integrationDetails?.slug + )} Expired`; + errorType = 'integration_expired_error'; + } + if (isExhausted) { + return new Response( + JSON.stringify({ + error: { + message: errorMessage, + type: errorType, + param: null, + code: '04', + }, + }), + { + headers: { + 'content-type': 'application/json', + }, + status: 412, + } + ); + } + if (isExpired) { + return new Response( + JSON.stringify({ + error: { + message: errorMessage, + type: errorType, + param: null, + code: '01', + }, + }), + { + headers: { + 'content-type': 'application/json', + }, + status: 
401,
+      }
+    );
+  }
+  if (
+    integrationDetails &&
+    model &&
+    !validateIntegrationModel(c, options, integrationDetails, model)
+  ) {
+    return new Response(
+      JSON.stringify({
+        error: {
+          message: `Model ${model} is not allowed for this integration`,
+          type: 'model_not_allowed_error',
+          param: null,
+          code: null,
+        },
+      }),
+      {
+        headers: {
+          'content-type': 'application/json',
+        },
+        status: 403,
+      }
+    );
+  }
+}
+
+async function validateEntityTokenRateLimits(
+  c: Context,
+  organisationDetails: OrganisationDetails,
+  virtualKeyDetails: VirtualKeyDetails | null,
+  integrationDetails: IntegrationDetails | null,
+  maxTokens: number
+) {
+  const rateLimitChecks: any[] = [];
+  rateLimitChecks.push(
+    ...validateApiKeyTokenRateLimits(c, organisationDetails, maxTokens)
+  );
+  rateLimitChecks.push(
+    ...validateWorkspaceTokenRateLimits(c, organisationDetails, maxTokens)
+  );
+  if (virtualKeyDetails) {
+    rateLimitChecks.push(
+      ...validateVirtualKeyRateLimits(c, virtualKeyDetails, maxTokens)
+    );
+  }
+  if (integrationDetails) {
+    rateLimitChecks.push(
+      ...validateIntegrationRateLimits(
+        c,
+        organisationDetails,
+        integrationDetails,
+        maxTokens
+      )
+    );
+  }
+  const results = await Promise.all(rateLimitChecks);
+  let isRateLimitExceeded = false;
+  let errorMessage = '';
+  let errorType = '';
+  for (const resp of results) {
+    const result = await resp.json();
+    if (result.allowed === false && !errorMessage) {
+      isRateLimitExceeded = true;
+      if (result.keyType === RateLimiterKeyTypes.API_KEY) {
+        errorMessage = `Portkey API Key ${hash(
+          result.key
+        )} Rate Limit Exceeded`;
+        errorType = 'api_key_rate_limit_error';
+      } else if (result.keyType === RateLimiterKeyTypes.WORKSPACE) {
+        errorMessage = `Portkey Workspace ${hash(
+          result.key
+        )} Rate Limit Exceeded`;
+        errorType = 'workspace_rate_limit_error';
+      } else if (result.keyType === RateLimiterKeyTypes.VIRTUAL_KEY) {
+        errorMessage = `Portkey Virtual Key ${hash(
+          result.key
+        )} Rate Limit Exceeded`;
+        errorType = 'virtual_key_rate_limit_error';
+      } else if (result.keyType ===
RateLimiterKeyTypes.INTEGRATION_WORKSPACE) { + errorMessage = `Portkey Integration ${hash( + result.key + )} Rate Limit Exceeded`; + errorType = 'integration_rate_limit_error'; + } + } + } + if (isRateLimitExceeded) { + return new Response( + JSON.stringify({ + error: { + message: errorMessage, + type: errorType, + param: null, + code: null, + }, + }), + { + headers: { + 'content-type': 'application/json', + }, + status: 429, + } + ); + } +} + +function validateApiKeyTokenRateLimits( + c: Context, + organisationDetails: OrganisationDetails, + maxTokens: number +) { + // validate only token rate limits + const rateLimits = organisationDetails.apiKeyDetails?.rateLimits?.filter( + (rateLimit) => rateLimit.type === RateLimiterTypes.TOKENS + ); + return preRequestRateLimitValidator({ + env: env(c), + rateLimits: rateLimits || [], + key: organisationDetails.apiKeyDetails?.key, + keyType: RateLimiterKeyTypes.API_KEY, + maxTokens, + organisationId: organisationDetails.id, + }); +} + +function validateWorkspaceTokenRateLimits( + c: Context, + organisationDetails: OrganisationDetails, + maxTokens: number +) { + // validate only token rate limits + const workspaceRateLimits = + organisationDetails.workspaceDetails?.rate_limits?.filter( + (rateLimit) => rateLimit.type === RateLimiterTypes.TOKENS + ); + return preRequestRateLimitValidator({ + env: env(c), + rateLimits: workspaceRateLimits || [], + key: organisationDetails.workspaceDetails?.slug, + keyType: RateLimiterKeyTypes.WORKSPACE, + maxTokens, + organisationId: organisationDetails.id, + }); +} + +function validateVirtualKeyStatus( + c: Context, + virtualKeyDetailsObj: VirtualKeyDetails | null, + metadata: Record +) { + if (!virtualKeyDetailsObj) { + return { + isExhausted: false, + isExpired: false, + }; + } + return preRequestUsageValidator({ + env: env(c), + entity: virtualKeyDetailsObj, + usageLimits: virtualKeyDetailsObj?.usage_limits || [], + metadata, + }); +} + +function validateVirtualKeyRateLimits( + c: Context, 
+ virtualKeyDetailsObj: VirtualKeyDetails, + maxTokens: number +) { + return preRequestRateLimitValidator({ + env: env(c), + rateLimits: virtualKeyDetailsObj?.rate_limits || [], + key: virtualKeyDetailsObj?.id, + keyType: RateLimiterKeyTypes.VIRTUAL_KEY, + maxTokens, + organisationId: virtualKeyDetailsObj?.organisation_id, + }); +} + +function validateIntegrationStatus( + c: Context, + integrationDetails: IntegrationDetails | null, + metadata: Record +) { + if (!integrationDetails) { + return { + isExhausted: false, + isExpired: false, + }; + } + return preRequestUsageValidator({ + env: env(c), + entity: integrationDetails, + usageLimits: integrationDetails?.usage_limits || [], + metadata, + }); +} + +function validateIntegrationModel( + c: Context, + options: Record, + integrationDetails: IntegrationDetails, + model: string +) { + let isModelAllowed = true; + if (integrationDetails && model) { + const allowAllModels = integrationDetails.allow_all_models; + let modelDetails; + if (!allowAllModels) { + modelDetails = integrationDetails.models?.find((m) => m.slug === model); + // Preserve old logic for backward compatibility. + // TODO: Remove this once we have migrated all the users to the new logic (alias as model). 
+ if ( + options.provider === AZURE_OPEN_AI && + options.azureModelName && + !modelDetails + ) { + modelDetails = integrationDetails.models?.find( + (m) => m.slug === options.azureModelName + ); + } + + if (modelDetails) { + options.modelPricingConfig = modelDetails.pricing_config; + } + + if (!modelDetails || modelDetails?.status === EntityStatus.ARCHIVED) { + isModelAllowed = false; + } + } + } + return isModelAllowed; +} + +function validateIntegrationRateLimits( + c: Context, + organisationDetails: OrganisationDetails, + integrationDetails: IntegrationDetails, + maxTokens: number +) { + return preRequestRateLimitValidator({ + env: env(c), + rateLimits: integrationDetails.rate_limits, + key: `${integrationDetails.id}-${organisationDetails.workspaceDetails.id}`, + keyType: RateLimiterKeyTypes.INTEGRATION_WORKSPACE, + maxTokens, + organisationId: organisationDetails.id, + }); +} +export const hash = (string: string | null | undefined) => { + if (string === null || string === undefined) return null; + //remove bearer from the string + if (string.startsWith('Bearer ')) string = string.slice(7, string.length); + return ( + string.slice(0, 2) + + '********' + + string.slice(string.length - 3, string.length) + ); +}; + +/** + * Updates headers object with default config_slug and metadata of an api key. + * + * @param {Object} headersObj - The original headers object to update. + * @param {OrganisationDetails} orgDetails - The organisation details object. 
+ */ +export function updateHeaders( + headersObj: Record, + orgDetails: OrganisationDetails +) { + if ( + headersObj[HEADER_KEYS.CONFIG] && + orgDetails.apiKeyDetails?.defaults?.config_slug && + orgDetails.apiKeyDetails?.defaults?.allow_config_override === false + ) { + throw new Error('Cannot override default config set for this API key.'); + } + + if ( + !headersObj[HEADER_KEYS.CONFIG] && + (orgDetails.apiKeyDetails?.defaults?.config_slug || + orgDetails.workspaceDetails?.defaults?.config_slug) + ) { + headersObj[HEADER_KEYS.CONFIG] = (orgDetails.apiKeyDetails?.defaults + ?.config_slug || + orgDetails.workspaceDetails?.defaults?.config_slug) as string; + } + + if ( + orgDetails.workspaceDetails?.defaults?.metadata || + orgDetails.apiKeyDetails?.defaults?.metadata || + orgDetails.apiKeyDetails?.systemDefaults?.user_name + ) { + let finalMetadata: Record = {}; + try { + const incomingMetadata = headersObj[HEADER_KEYS.METADATA] + ? JSON.parse(headersObj[HEADER_KEYS.METADATA]) + : {}; + finalMetadata = { + ...incomingMetadata, + ...(orgDetails.apiKeyDetails?.defaults?.metadata || {}), + ...(orgDetails.workspaceDetails?.defaults?.metadata || {}), + }; + } catch (err) { + finalMetadata = { + ...(orgDetails.apiKeyDetails?.defaults?.metadata || {}), + ...(orgDetails.workspaceDetails?.defaults?.metadata || {}), + }; + } + const systemUserName = orgDetails.apiKeyDetails?.systemDefaults?.user_name; + if (systemUserName) { + if ( + orgDetails.apiKeyDetails?.systemDefaults?.user_key_metadata_override + ) { + // if override, precedence to existing user passed + finalMetadata._user = finalMetadata._user || systemUserName; + } else { + // use system user name irrespective of passed + finalMetadata._user = systemUserName; + } + } + headersObj[HEADER_KEYS.METADATA] = JSON.stringify(finalMetadata); + } + + // These 2 headers can only be injected by Portkey internally. + // They are not meant to be passed by the user. So we enforce this by deleting them. 
+ delete headersObj[HEADER_KEYS.DEFAULT_INPUT_GUARDRAILS]; + delete headersObj[HEADER_KEYS.DEFAULT_OUTPUT_GUARDRAILS]; +} + +export function constructAzureFoundryURL( + modelConfig: { + azureDeploymentType?: string; + azureDeploymentName?: string; + azureRegion?: string; + azureEndpointName?: string; + } = {} +) { + if (modelConfig.azureDeploymentType === 'serverless') { + return `https://${modelConfig.azureDeploymentName?.toLowerCase()}.${ + modelConfig.azureRegion + }.models.ai.azure.com`; + } else if (modelConfig.azureDeploymentType === 'managed') { + return `https://${modelConfig.azureEndpointName}.${modelConfig.azureRegion}.inference.ml.azure.com/score`; + } +} + +export const addBackgroundTask = ( + c: Context, + promise: Promise +) => { + if (runtime === 'workerd') { + c.executionCtx.waitUntil(promise); + } + // in other runtimes, the promise resolves in the background +}; + +export const handleTokenRateLimit = ( + winkyBaseLog: WinkyLogObject, + responseBodyJson: Record, + env: any +) => { + let totalTokens = 0; + if (winkyBaseLog.responseStatus >= 200 && winkyBaseLog.responseStatus < 300) { + switch (winkyBaseLog.rubeusURL) { + case 'chatComplete': + case 'complete': + totalTokens = responseBodyJson.usage.total_tokens; + break; + case 'messages': + totalTokens = + responseBodyJson.usage?.input_tokens + + (responseBodyJson.usage?.cache_creation_input_tokens ?? 0) + + (responseBodyJson.usage?.cache_read_input_tokens ?? 
0) + + responseBodyJson.usage.output_tokens; + break; + default: + totalTokens = 0; + } + // do not await results + handleIntegrationRequestRateLimits(env, winkyBaseLog, totalTokens); + } +}; diff --git a/src/middlewares/portkey/utils/anthropicMessagesStreamParser.ts b/src/middlewares/portkey/utils/anthropicMessagesStreamParser.ts new file mode 100644 index 000000000..7e712865a --- /dev/null +++ b/src/middlewares/portkey/utils/anthropicMessagesStreamParser.ts @@ -0,0 +1,159 @@ +const JSON_BUF_PROPERTY = '__json_buf'; +export type TracksToolInput = any; + +function tracksToolInput(content: any): content is TracksToolInput { + return content.type === 'tool_use' || content.type === 'server_tool_use'; +} + +export const parseAnthropicMessageStreamResponse = ( + res: string, + splitPattern: string +): any => { + const arr = res.split(splitPattern); + let snapshot: any | undefined; + try { + for (let eachFullChunk of arr) { + eachFullChunk = eachFullChunk.trim(); + eachFullChunk = eachFullChunk + .replace(/^event:.*$/gm, '') + .replace(/^\s*\n/gm, ''); + eachFullChunk = eachFullChunk.replace(/^data: /, ''); + eachFullChunk = eachFullChunk.trim(); + const event: any = JSON.parse(eachFullChunk || '{}'); + + if (event.type === 'ping') { + continue; + } + + if (event.type === 'message_start') { + snapshot = event.message; + if (!snapshot.usage) + snapshot.usage = { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 0, + }; + continue; + } + + if (!snapshot) { + throw new Error('Unexpected ordering of events'); + } + + switch (event.type) { + case 'message_delta': + snapshot.stop_reason = event.delta.stop_reason; + snapshot.stop_sequence = event.delta.stop_sequence; + snapshot.usage.output_tokens = event.usage.output_tokens; + + // Update other usage fields if they exist in the event + if (event.usage.input_tokens != null) { + snapshot.usage.input_tokens = event.usage.input_tokens; + } + + if 
(event.usage.cache_creation_input_tokens != null) { + snapshot.usage.cache_creation_input_tokens = + event.usage.cache_creation_input_tokens; + } + + if (event.usage.cache_read_input_tokens != null) { + snapshot.usage.cache_read_input_tokens = + event.usage.cache_read_input_tokens; + } + + if (event.usage.server_tool_use != null) { + snapshot.usage.server_tool_use = event.usage.server_tool_use; + } + break; + // we return from here + case 'message_stop': + return snapshot; + case 'content_block_start': + snapshot.content.push({ ...event.content_block }); + break; + case 'content_block_delta': { + const snapshotContent = snapshot.content.at(event.index); + + switch (event.delta.type) { + case 'text_delta': { + if (snapshotContent?.type === 'text') { + snapshot.content[event.index] = { + ...snapshotContent, + text: (snapshotContent.text || '') + event.delta.text, + }; + } + break; + } + case 'citations_delta': { + if (snapshotContent?.type === 'text') { + snapshot.content[event.index] = { + ...snapshotContent, + citations: [ + ...(snapshotContent.citations ?? 
[]), + event.delta.citation, + ], + }; + } + break; + } + case 'input_json_delta': { + if (snapshotContent && tracksToolInput(snapshotContent)) { + // we need to keep track of the raw JSON string as well so that we can + // re-parse it for each delta, for now we just store it as an untyped + // non-enumerable property on the snapshot + let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || ''; + jsonBuf += event.delta.partial_json; + + const newContent = { ...snapshotContent }; + Object.defineProperty(newContent, JSON_BUF_PROPERTY, { + value: jsonBuf, + enumerable: false, + writable: true, + }); + + if (jsonBuf) { + try { + // only set input if it's valid JSON + newContent.input = JSON.parse(jsonBuf); + } catch (error) { + // ignore error + } + } + snapshot.content[event.index] = newContent; + } + break; + } + case 'thinking_delta': { + if (snapshotContent?.type === 'thinking') { + snapshot.content[event.index] = { + ...snapshotContent, + thinking: snapshotContent.thinking + event.delta.thinking, + }; + } + break; + } + case 'signature_delta': { + if (snapshotContent?.type === 'thinking') { + snapshot.content[event.index] = { + ...snapshotContent, + signature: event.delta.signature, + }; + } + break; + } + } + break; + } + case 'content_block_stop': + break; + } + } + } catch (error: any) { + console.error({ + message: `parseAnthropicMessageStreamResponse: ${error.message}`, + }); + snapshot = undefined; + } + return snapshot; +}; diff --git a/src/public/index.html b/src/public/index.html index 9bd7e77e2..d7328b158 100644 --- a/src/public/index.html +++ b/src/public/index.html @@ -508,6 +508,435 @@ } } + /* integrations.css */ + /* Integrations styles */ + .card.integrations-card { + margin-top: 1rem; + max-width: 1000px; + } + + .integrations-header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 1rem; + } + + .integrations-list { + display: flex; + flex-direction: column; + gap: 1rem; + } + + .integration-item { 
+ background-color: white; + border: 1px solid #e5e7eb; + border-radius: 0.5rem; + overflow: hidden; + transition: all 0.2s ease; + } + + .integration-item:hover { + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); + } + + .integration-header { + display: flex; + justify-content: space-between; + align-items: center; + padding: 1rem; + background-color: #f9fafb; + cursor: pointer; + border-bottom: 1px solid #e5e7eb; + } + + .integration-header:hover { + background-color: #f3f4f6; + } + + .integration-title { + display: flex; + align-items: center; + gap: 0.5rem; + font-weight: 600; + color: #374151; + } + + .integration-provider { + background-color: #3b82f6; + color: white; + padding: 0.25rem 0.5rem; + border-radius: 0.25rem; + font-size: 0.75rem; + font-weight: 500; + } + + .integration-slug { + color: #6b7280; + font-size: 0.875rem; + } + + .integration-actions { + display: flex; + gap: 0.5rem; + align-items: center; + } + + .integration-toggle { + background: none; + border: none; + cursor: pointer; + padding: 0.25rem; + color: #6b7280; + transition: color 0.2s; + } + + .integration-toggle:hover { + color: #374151; + } + + .integration-toggle.expanded { + transform: rotate(180deg); + } + + .integration-content { + display: none; + padding: 1rem; + } + + .integration-content.expanded { + display: block; + } + + .integration-form { + display: grid; + gap: 1rem; + } + + .form-group { + display: flex; + flex-direction: column; + gap: 0.5rem; + } + + .form-group label { + font-size: 0.875rem; + font-weight: 500; + color: #374151; + } + + .form-group input, + .form-group select, + .form-group textarea { + padding: 0.5rem; + border: 1px solid #d1d5db; + border-radius: 0.375rem; + font-size: 0.875rem; + transition: border-color 0.2s; + } + + .form-group input:focus, + .form-group select:focus, + .form-group textarea:focus { + outline: none; + border-color: #3b82f6; + box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.1); + } + + .form-group textarea { + resize: vertical; + 
min-height: 80px; + } + + .form-row { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 1rem; + } + + .rate-limits-section { + background-color: #f9fafb; + border: 1px solid #e5e7eb; + border-radius: 0.375rem; + padding: 1rem; + margin-top: 1rem; + } + + .rate-limits-header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 1rem; + } + + .rate-limits-title { + font-weight: 600; + color: #374151; + } + + .rate-limit-item { + display: grid; + grid-template-columns: 1fr 1fr 1fr auto; + gap: 1rem; + align-items: end; + margin-bottom: 0.5rem; + } + + .rate-limit-item:last-child { + margin-bottom: 0; + } + + .rate-limit-actions { + display: flex; + gap: 0.5rem; + align-items: center; + } + + .btn-reset-rate-limit { + background-color: #f59e0b; + color: white; + border: none; + border-radius: 0.25rem; + padding: 0.5rem; + cursor: pointer; + font-size: 0.75rem; + transition: background-color 0.2s; + } + + .btn-reset-rate-limit:hover { + background-color: #d97706; + } + + .btn-remove-rate-limit { + background-color: #ef4444; + color: white; + border: none; + border-radius: 0.25rem; + padding: 0.5rem; + cursor: pointer; + font-size: 0.75rem; + transition: background-color 0.2s; + } + + .btn-remove-rate-limit:hover { + background-color: #dc2626; + } + + .btn-add-rate-limit { + background-color: #10b981; + color: white; + border: none; + border-radius: 0.25rem; + padding: 0.5rem 1rem; + cursor: pointer; + font-size: 0.875rem; + margin-top: 0.5rem; + transition: background-color 0.2s; + } + + .btn-add-rate-limit:hover { + background-color: #059669; + } + + .btn-reset-rate-limits { + background-color: #f59e0b; + color: white; + border: none; + border-radius: 0.25rem; + padding: 0.5rem 1rem; + cursor: pointer; + font-size: 0.875rem; + transition: background-color 0.2s; + } + + .btn-reset-rate-limits:hover { + background-color: #d97706; + } + + .models-section { + background-color: #f9fafb; + border: 1px solid #e5e7eb; + 
border-radius: 0.375rem; + padding: 1rem; + margin-top: 1rem; + } + + .models-header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 1rem; + } + + .models-title { + font-weight: 600; + color: #374151; + } + + .model-item { + display: grid; + grid-template-columns: 1fr auto; + gap: 1rem; + align-items: center; + padding: 0.5rem; + background-color: white; + border: 1px solid #e5e7eb; + border-radius: 0.25rem; + margin-bottom: 0.5rem; + } + + .model-item:last-child { + margin-bottom: 0; + } + + .model-slug { + font-weight: 500; + color: #374151; + } + + .model-status { + padding: 0.25rem 0.5rem; + border-radius: 0.25rem; + font-size: 0.75rem; + font-weight: 500; + } + + .model-status.active { + background-color: #dcfce7; + color: #166534; + } + + .model-status.inactive { + background-color: #fee2e2; + color: #991b1b; + } + + .btn-add-model { + background-color: #10b981; + color: white; + border: none; + border-radius: 0.25rem; + padding: 0.5rem 1rem; + cursor: pointer; + font-size: 0.875rem; + margin-top: 0.5rem; + transition: background-color 0.2s; + } + + .btn-add-model:hover { + background-color: #059669; + } + + .integration-form-actions { + display: flex; + gap: 0.5rem; + justify-content: flex-end; + margin-top: 1rem; + padding-top: 1rem; + border-top: 1px solid #e5e7eb; + } + + .btn-save { + background-color: #3b82f6; + color: white; + } + + .btn-save:hover { + background-color: #2563eb; + } + + .btn-cancel { + background-color: #6b7280; + color: white; + } + + .btn-cancel:hover { + background-color: #4b5563; + } + + .btn-delete { + background-color: #ef4444; + color: white; + } + + .btn-delete:hover { + background-color: #dc2626; + } + + .credential-item { + display: grid; + grid-template-columns: 1fr 1fr auto; + gap: 1rem; + align-items: end; + margin-bottom: 0.5rem; + padding: 1rem; + background-color: white; + border: 1px solid #e5e7eb; + border-radius: 0.375rem; + } + + .credential-item:last-child { + 
margin-bottom: 0; + } + + .btn-add-credential { + background-color: #10b981; + color: white; + border: none; + border-radius: 0.25rem; + padding: 0.5rem 1rem; + cursor: pointer; + font-size: 0.875rem; + margin-top: 0.5rem; + transition: background-color 0.2s; + } + + .btn-add-credential:hover { + background-color: #059669; + } + + .btn-remove-credential { + background-color: #ef4444; + color: white; + border: none; + border-radius: 0.25rem; + padding: 0.5rem; + cursor: pointer; + font-size: 0.75rem; + transition: background-color 0.2s; + } + + .btn-remove-credential:hover { + background-color: #dc2626; + } + + /* Responsive adjustments */ + @media (max-width: 768px) { + .integrations-header { + flex-direction: column; + align-items: stretch; + gap: 1rem; + } + + .form-row { + grid-template-columns: 1fr; + } + + .rate-limit-item { + grid-template-columns: 1fr; + gap: 0.5rem; + } + + .credential-item { + grid-template-columns: 1fr; + gap: 0.5rem; + } + + .integration-form-actions { + flex-direction: column; + } + } + /* modal.css */ /* Modal styles */ .modal { @@ -885,6 +1314,7 @@ Real-time Logs +
+
+
+

Integrations

+ +
+
+
+
+ Loading integrations... +
+
+
+
+
-
-
-
- - -
-
- - -
-
- -
- -
- ${renderCredentials(integration.credentials, index)} -
-
- -
-
-
Rate Limits
+
+
+

Integration Configuration

+
${hasRateLimits ? ` - ` : ''} + + +
-
- ${hasRateLimits ? renderRateLimits(integration.rate_limits, index) : '
No rate limits configured
'} -
-
- - ${hasModels ? ` -
-
-
Models
-
-
- ${renderModels(integration.models, index)} -
- +
+ +
- ` : ''} - -
- - - +
+

Note: Edit the JSON configuration above. Make sure the JSON is valid before saving.

+

Common fields: provider, slug, credentials, rate_limits, models, allow_all_models

- +
`; @@ -2804,12 +2829,10 @@

Request Details

function addNewIntegration() { const newIntegration = { provider: 'openai', - slug: 'new-integration', + slug: crypto.randomUUID().substring(0, 8), credentials: { - apiKey: '' }, - rate_limits: [], - models: [] + allow_all_models: true, }; integrationsData.push(newIntegration); @@ -2824,72 +2847,44 @@

Request Details

// Save integration async function saveIntegration(index) { - const form = document.getElementById(`form-${index}`); - const formData = new FormData(form); - - // Collect form data - const integration = { - provider: formData.get('provider'), - slug: formData.get('slug'), - credentials: {}, - rate_limits: [], - models: integrationsData[index].models || [] - }; - - // Collect credentials - const credentialEntries = []; - for (const [key, value] of formData.entries()) { - if (key.startsWith('credentials.')) { - const parts = key.split('.'); - const credIndex = parseInt(parts[1]); - const field = parts[2]; - - if (!credentialEntries[credIndex]) { - credentialEntries[credIndex] = {}; - } - credentialEntries[credIndex][field] = value; + try { + const textarea = document.getElementById(`json-editor-${index}`); + const jsonText = textarea.value.trim(); + + if (!jsonText) { + alert('Please enter integration configuration'); + return; } - } - - credentialEntries.forEach(entry => { - if (entry.key && entry.value) { - integration.credentials[entry.key] = entry.value; + + // Parse and validate JSON + let integration; + try { + integration = JSON.parse(jsonText); + } catch (error) { + alert(`Invalid JSON: ${error.message}`); + return; } - }); - - // Collect rate limits - const rateLimitEntries = []; - for (const [key, value] of formData.entries()) { - if (key.startsWith('rate_limits.')) { - const parts = key.split('.'); - const limitIndex = parseInt(parts[1]); - const field = parts[2]; - - if (!rateLimitEntries[limitIndex]) { - rateLimitEntries[limitIndex] = {}; - } - rateLimitEntries[limitIndex][field] = value; + + // Basic validation + if (!integration.provider || !integration.slug) { + alert('Integration must have provider and slug fields'); + return; } + + // Update the integration data + integrationsData[index] = integration; + + // Save to file + await saveIntegrationsToFile(); + + // Re-render the integrations + renderIntegrations(); + + alert('Integration saved 
successfully!'); + } catch (error) { + console.error('Error saving integration:', error); + alert(`Failed to save integration: ${error.message}`); } - - rateLimitEntries.forEach(entry => { - if (entry.type && entry.value && entry.unit) { - integration.rate_limits.push({ - type: entry.type, - value: parseInt(entry.value), - unit: entry.unit - }); - } - }); - - // Update the integration data - integrationsData[index] = integration; - - // Save to settings.json - await saveIntegrationsToFile(); - - // Re-render to update the display - renderIntegrations(); } // Cancel edit @@ -2997,43 +2992,44 @@

Request Details

} } - // Reset individual rate limit - async function resetIndividualRateLimit(integrationIndex, rateLimitIndex) { - if (confirm('Are you sure you want to reset this specific rate limit?')) { - try { - // Get the integration and rate limit info - const integration = integrationsData[integrationIndex]; - const rateLimit = integration.rate_limits[rateLimitIndex]; - - // Create a cache key based on the integration and rate limit - // const cacheKey = `rate_limit:${integration.slug}:${rateLimit.type}:${rateLimit.unit}`; - - // Delete the cache entry for this rate limit - const response = await fetch(`/admin/integrations/ratelimit/${integration.slug}/${rateLimit.type}/reset`, { - method: 'PUT', - headers: { - 'Authorization': `Bearer ${adminApiKey}`, - 'x-admin-api-key': adminApiKey - } - }); - - if (!response.ok) { - if (response.status === 401) { - localStorage.removeItem('adminApiKey'); - adminApiKey = ''; - throw new Error('Invalid Admin API Key. Please refresh and try again.'); + // Reset all rate limits for an integration + async function resetAllRateLimits(integrationIndex) { + if (confirm('Are you sure you want to reset all rate limits for this integration?')) { + try { + const integration = integrationsData[integrationIndex]; + + if (!integration.rate_limits || integration.rate_limits.length === 0) { + alert('No rate limits configured for this integration'); + return; + } + + console.log('Resetting rate limits for integration:', integration.slug); + + const response = await fetch(`/admin/integrations/ratelimit/${integration.slug}/reset`, { + method: 'PUT', + headers: { + 'Authorization': `Bearer ${adminApiKey}`, + 'x-admin-api-key': adminApiKey + } + }); + + if (!response.ok) { + if (response.status === 401) { + localStorage.removeItem('adminApiKey'); + adminApiKey = ''; + throw new Error('Invalid Admin API Key. 
Please refresh and try again.');
+            }
+            throw new Error(`Failed to reset rate limits for ${integration.slug}`);
+          }
+
+          console.log(`All rate limits reset for integration ${integration.slug}`);
+          alert(`All rate limits reset successfully for ${integration.slug}`);
+        } catch (error) {
+          console.error('Error resetting rate limits:', error);
+          alert(`Failed to reset rate limits: ${error.message}`);
+        }
+      }
+    }
 }
-          throw new Error('Failed to reset rate limit cache');
-        }
-
-        console.log(`Rate limit cache cleared for integration ${integrationIndex}, rate limit ${rateLimitIndex}`);
-        alert(`Rate limit reset successfully for ${integration.slug} (${rateLimit.type} ${rateLimit.unit})`);
-      } catch (error) {
-        console.error('Error resetting rate limit:', error);
-        alert(`Failed to reset rate limit: ${error.message}`);
-      }
-    }
-  }

     // Add credential field
     function addCredential(index) {
@@ -3070,7 +3066,7 @@

Request Details

} } - // Save integrations to settings.json file + // Save integrations to conf.json file async function saveIntegrationsToFile() { try { const settingsData = {