|
| 1 | +import { AiProvider } from '@prisma/client'; |
| 2 | + |
/**
 * Connection settings for a single AI provider.
 *
 * For AZURE_OPENAI, `endpointBase` must be the resource URL
 * (e.g. https://your-resource.openai.azure.com) and `defaultModel` holds the
 * deployment name — see createAIProviderAdapter below.
 */
export type ProviderConfig = {
  provider: AiProvider;
  /** Secret API key sent with every request. */
  apiKey: string;
  /** Overrides the provider's built-in API base URL when set. */
  endpointBase?: string;
  /** Model used when a call does not specify one (Azure: deployment name). */
  defaultModel?: string;
  /** OpenAI organization id, sent as the OpenAI-Organization header. */
  orgId?: string;
};
| 10 | + |
/**
 * Minimal chat interface implemented by every provider adapter returned from
 * createAIProviderAdapter.
 */
export interface AIProviderAdapter {
  /**
   * Sends `prompt` to the provider and resolves with the model's reply parsed
   * as JSON. Implementations fall back to `{ raw: text }` when the reply is
   * not valid JSON, and reject on non-2xx HTTP status or network timeout.
   *
   * @param prompt - User message sent to the model.
   * @param opts - Optional per-call model name and system prompt override.
   */
  chatJSON(
    prompt: string,
    opts?: { model?: string; system?: string }
  ): Promise<any>;
}
| 17 | + |
/**
 * Built-in endpoint and model defaults per provider, used whenever
 * ProviderConfig leaves endpointBase/defaultModel unset. API-version values
 * are read from the environment ONCE, at module load time.
 */
const DEFAULTS = {
  OPENAI: { base: 'https://api.openai.com/v1', model: 'gpt-4o-mini' },
  ANTHROPIC: {
    base: 'https://api.anthropic.com',
    path: '/v1/messages',
    model: 'claude-3-5-sonnet-latest',
    // Sent as the anthropic-version header; override via ANTHROPIC_API_VERSION.
    version: process.env.ANTHROPIC_API_VERSION || '2023-06-01',
  },
  GOOGLE: {
    base: 'https://generativelanguage.googleapis.com',
    model: 'gemini-1.5-flash',
  },
  AZURE_OPENAI: {
    // endpointBase should be like: https://your-resource.openai.azure.com
    apiVersion: process.env.AZURE_OPENAI_API_VERSION || '2024-02-15-preview',
  },
};
| 35 | + |
| 36 | +async function requestWithTimeout( |
| 37 | + url: string, |
| 38 | + init: RequestInit, |
| 39 | + timeoutMs = 12_000 |
| 40 | +): Promise<Response> { |
| 41 | + const controller = new AbortController(); |
| 42 | + const id = setTimeout(() => controller.abort(), timeoutMs); |
| 43 | + try { |
| 44 | + const res = await fetch(url, { ...init, signal: controller.signal }); |
| 45 | + return res; |
| 46 | + } finally { |
| 47 | + clearTimeout(id); |
| 48 | + } |
| 49 | +} |
| 50 | + |
| 51 | +async function withRetry<T>(fn: () => Promise<T>, attempts = 2): Promise<T> { |
| 52 | + let lastErr: any; |
| 53 | + for (let i = 0; i < attempts; i++) { |
| 54 | + try { |
| 55 | + return await fn(); |
| 56 | + } catch (e) { |
| 57 | + lastErr = e; |
| 58 | + await new Promise((r) => setTimeout(r, (i + 1) * 300)); |
| 59 | + } |
| 60 | + } |
| 61 | + throw lastErr; |
| 62 | +} |
| 63 | + |
| 64 | +export function createAIProviderAdapter( |
| 65 | + cfg: ProviderConfig |
| 66 | +): AIProviderAdapter { |
| 67 | + if (cfg.provider === 'OPENAI') { |
| 68 | + const base = cfg.endpointBase || DEFAULTS.OPENAI.base; |
| 69 | + const model = cfg.defaultModel || DEFAULTS.OPENAI.model; |
| 70 | + return { |
| 71 | + async chatJSON(prompt, opts) { |
| 72 | + const sys = |
| 73 | + opts?.system || |
| 74 | + 'You are a helpful analysis assistant. Output strictly valid JSON only.'; |
| 75 | + const body = { |
| 76 | + model: opts?.model || model, |
| 77 | + messages: [ |
| 78 | + { role: 'system', content: sys }, |
| 79 | + { role: 'user', content: prompt }, |
| 80 | + ], |
| 81 | + temperature: 0.2, |
| 82 | + response_format: { type: 'json_object' }, |
| 83 | + }; |
| 84 | + const res = await withRetry(() => |
| 85 | + requestWithTimeout(`${base}/chat/completions`, { |
| 86 | + method: 'POST', |
| 87 | + headers: { |
| 88 | + 'Content-Type': 'application/json', |
| 89 | + Authorization: `Bearer ${cfg.apiKey}`, |
| 90 | + ...(cfg.orgId ? { 'OpenAI-Organization': cfg.orgId } : {}), |
| 91 | + }, |
| 92 | + body: JSON.stringify(body), |
| 93 | + }) |
| 94 | + ); |
| 95 | + if (!res.ok) throw new Error(`OpenAI error ${res.status}`); |
| 96 | + const data: any = await res.json(); |
| 97 | + const choices: any[] | undefined = Array.isArray((data as any)?.choices) |
| 98 | + ? (data as any).choices |
| 99 | + : undefined; |
| 100 | + const text = |
| 101 | + choices && choices[0] && choices[0].message?.content |
| 102 | + ? String(choices[0].message.content) |
| 103 | + : '{}'; |
| 104 | + try { |
| 105 | + return JSON.parse(text); |
| 106 | + } catch { |
| 107 | + return { raw: text }; |
| 108 | + } |
| 109 | + }, |
| 110 | + }; |
| 111 | + } |
| 112 | + |
| 113 | + if (cfg.provider === 'ANTHROPIC') { |
| 114 | + const base = cfg.endpointBase || DEFAULTS.ANTHROPIC.base; |
| 115 | + const path = DEFAULTS.ANTHROPIC.path; |
| 116 | + const model = cfg.defaultModel || DEFAULTS.ANTHROPIC.model; |
| 117 | + const version = DEFAULTS.ANTHROPIC.version; |
| 118 | + return { |
| 119 | + async chatJSON(prompt, _opts) { |
| 120 | + const body = { |
| 121 | + model, |
| 122 | + max_tokens: 1024, |
| 123 | + messages: [{ role: 'user', content: prompt }], |
| 124 | + } as any; |
| 125 | + const res = await withRetry(() => |
| 126 | + requestWithTimeout(`${base}${path}`, { |
| 127 | + method: 'POST', |
| 128 | + headers: { |
| 129 | + 'Content-Type': 'application/json', |
| 130 | + 'x-api-key': cfg.apiKey, |
| 131 | + 'anthropic-version': version, |
| 132 | + }, |
| 133 | + body: JSON.stringify(body), |
| 134 | + }) |
| 135 | + ); |
| 136 | + if (!res.ok) throw new Error(`Anthropic error ${res.status}`); |
| 137 | + const data: any = await res.json(); |
| 138 | + const text: string = |
| 139 | + data?.content?.[0]?.text || data?.content?.[0]?.content?.[0]?.text || '{}'; |
| 140 | + try { |
| 141 | + return JSON.parse(text); |
| 142 | + } catch { |
| 143 | + return { raw: text }; |
| 144 | + } |
| 145 | + }, |
| 146 | + }; |
| 147 | + } |
| 148 | + |
| 149 | + if (cfg.provider === 'GOOGLE') { |
| 150 | + const base = cfg.endpointBase || DEFAULTS.GOOGLE.base; |
| 151 | + const model = cfg.defaultModel || DEFAULTS.GOOGLE.model; |
| 152 | + return { |
| 153 | + async chatJSON(prompt, _opts) { |
| 154 | + // Google uses API key in query string |
| 155 | + const url = `${base}/v1beta/models/${encodeURIComponent(model)}:generateContent?key=${encodeURIComponent(cfg.apiKey)}`; |
| 156 | + const body = { |
| 157 | + contents: [{ role: 'user', parts: [{ text: prompt }] }], |
| 158 | + generationConfig: { temperature: 0.2 }, |
| 159 | + } as any; |
| 160 | + const res = await withRetry(() => |
| 161 | + requestWithTimeout(url, { |
| 162 | + method: 'POST', |
| 163 | + headers: { 'Content-Type': 'application/json' }, |
| 164 | + body: JSON.stringify(body), |
| 165 | + }) |
| 166 | + ); |
| 167 | + if (!res.ok) throw new Error(`Google error ${res.status}`); |
| 168 | + const data: any = await res.json(); |
| 169 | + const text: string = |
| 170 | + data?.candidates?.[0]?.content?.parts?.[0]?.text || '{}'; |
| 171 | + try { |
| 172 | + return JSON.parse(text); |
| 173 | + } catch { |
| 174 | + return { raw: text }; |
| 175 | + } |
| 176 | + }, |
| 177 | + }; |
| 178 | + } |
| 179 | + |
| 180 | + if (cfg.provider === 'AZURE_OPENAI') { |
| 181 | + const base = cfg.endpointBase?.replace(/\/$/, '') || ''; |
| 182 | + const deployment = cfg.defaultModel; // here defaultModel should be deployment name |
| 183 | + const apiVersion = DEFAULTS.AZURE_OPENAI.apiVersion; |
| 184 | + if (!base || !deployment) { |
| 185 | + return { |
| 186 | + async chatJSON() { |
| 187 | + throw new Error( |
| 188 | + 'Azure OpenAI requires endpointBase and defaultModel (deployment name)' |
| 189 | + ); |
| 190 | + }, |
| 191 | + }; |
| 192 | + } |
| 193 | + return { |
| 194 | + async chatJSON(prompt, opts) { |
| 195 | + const sys = |
| 196 | + opts?.system || |
| 197 | + 'You are a helpful analysis assistant. Output strictly valid JSON only.'; |
| 198 | + const body = { |
| 199 | + messages: [ |
| 200 | + { role: 'system', content: sys }, |
| 201 | + { role: 'user', content: prompt }, |
| 202 | + ], |
| 203 | + temperature: 0.2, |
| 204 | + response_format: { type: 'json_object' }, |
| 205 | + } as any; |
| 206 | + const url = `${base}/openai/deployments/${encodeURIComponent(deployment)}/chat/completions?api-version=${encodeURIComponent(apiVersion)}`; |
| 207 | + const res = await withRetry(() => |
| 208 | + requestWithTimeout(url, { |
| 209 | + method: 'POST', |
| 210 | + headers: { |
| 211 | + 'Content-Type': 'application/json', |
| 212 | + 'api-key': cfg.apiKey, |
| 213 | + }, |
| 214 | + body: JSON.stringify(body), |
| 215 | + }) |
| 216 | + ); |
| 217 | + if (!res.ok) throw new Error(`Azure OpenAI error ${res.status}`); |
| 218 | + const data: any = await res.json(); |
| 219 | + const text = data?.choices?.[0]?.message?.content || '{}'; |
| 220 | + try { |
| 221 | + return JSON.parse(text); |
| 222 | + } catch { |
| 223 | + return { raw: text }; |
| 224 | + } |
| 225 | + }, |
| 226 | + }; |
| 227 | + } |
| 228 | + |
| 229 | + // Custom HTTP adapter: POST to endpointBase with { prompt }, expect JSON response |
| 230 | + if (cfg.provider === 'CUSTOM') { |
| 231 | + const base = cfg.endpointBase || ''; |
| 232 | + return { |
| 233 | + async chatJSON(prompt) { |
| 234 | + if (!base) throw new Error('Custom provider requires endpointBase'); |
| 235 | + const res = await withRetry(() => |
| 236 | + requestWithTimeout(base, { |
| 237 | + method: 'POST', |
| 238 | + headers: { |
| 239 | + 'Content-Type': 'application/json', |
| 240 | + Authorization: `Bearer ${cfg.apiKey}`, |
| 241 | + }, |
| 242 | + body: JSON.stringify({ prompt, mode: 'json' }), |
| 243 | + }) |
| 244 | + ); |
| 245 | + if (!res.ok) throw new Error(`Custom provider error ${res.status}`); |
| 246 | + const data: any = await res.json(); |
| 247 | + // Assume provider already returns JSON object |
| 248 | + return data; |
| 249 | + }, |
| 250 | + }; |
| 251 | + } |
| 252 | + |
| 253 | + // Default unsupported provider |
| 254 | + return { |
| 255 | + async chatJSON() { |
| 256 | + throw new Error('Unsupported provider'); |
| 257 | + }, |
| 258 | + }; |
| 259 | +} |