diff --git a/apps/api/src/env.ts b/apps/api/src/env.ts
index bd4ed2238b..81cf085ffb 100644
--- a/apps/api/src/env.ts
+++ b/apps/api/src/env.ts
@@ -31,6 +31,7 @@ export const env = createEnv({
     SLACK_SIGNING_SECRET: z.string().optional(),
     LOOPS_API_KEY: z.string().optional(),
     LOOPS_SLACK_CHANNEL_ID: z.string().optional(),
+    AI_SERVICE_URL: z.url().optional(),
   },
   runtimeEnv: Bun.env,
   emptyStringAsUndefined: true,
diff --git a/apps/api/src/index.ts b/apps/api/src/index.ts
index 9fd4f727cb..d0f438063a 100644
--- a/apps/api/src/index.ts
+++ b/apps/api/src/index.ts
@@ -12,6 +12,9 @@ import { logger } from "hono/logger";
 import { env } from "./env";
 import type { AppBindings } from "./hono-bindings";
 import {
+  forwardLlmToAi,
+  forwardSttListenToAi,
+  forwardSttTranscribeToAi,
   loadTestOverride,
   observabilityMiddleware,
   sentryMiddleware,
@@ -50,13 +53,28 @@ app.use("*", (c, next) => {
   return corsMiddleware(c, next);
 });
 
-app.use("/chat/completions", loadTestOverride, supabaseAuthMiddleware);
+app.use(
+  "/chat/completions",
+  loadTestOverride,
+  supabaseAuthMiddleware,
+  forwardLlmToAi,
+);
 app.use("/webhook/stripe", verifyStripeWebhook);
 app.use("/webhook/slack/events", verifySlackWebhook);
 
 if (env.NODE_ENV !== "development") {
-  app.use("/listen", loadTestOverride, supabaseAuthMiddleware);
-  app.use("/transcribe", loadTestOverride, supabaseAuthMiddleware);
+  app.use(
+    "/listen",
+    loadTestOverride,
+    supabaseAuthMiddleware,
+    forwardSttListenToAi,
+  );
+  app.use(
+    "/transcribe",
+    loadTestOverride,
+    supabaseAuthMiddleware,
+    forwardSttTranscribeToAi,
+  );
 }
 
 app.route("/", routes);
diff --git a/apps/api/src/listen.ts b/apps/api/src/listen.ts
index cf6b3899db..c521adcaca 100644
--- a/apps/api/src/listen.ts
+++ b/apps/api/src/listen.ts
@@ -2,6 +2,7 @@ import * as Sentry from "@sentry/bun";
 import type { Handler } from "hono";
 import { upgradeWebSocket } from "hono/bun";
 
+import { env } from "./env";
 import type { AppBindings } from "./hono-bindings";
 import {
   createProxyFromRequest,
@@ -9,6 +10,30 @@ import {
   WsProxyConnection,
 } from "./stt";
 
+/**
+ * Builds a WebSocket proxy connection to the AI service's `/stt/listen`
+ * endpoint, carrying over the client's query string and, when present, its
+ * `Authorization` header.
+ *
+ * @throws Error if `AI_SERVICE_URL` is not configured.
+ */
+function createAiServiceProxy(
+  clientUrl: URL,
+  reqHeaders: Headers,
+): WsProxyConnection {
+  if (!env.AI_SERVICE_URL) {
+    throw new Error("AI_SERVICE_URL is not configured");
+  }
+
+  const aiServiceUrl = new URL("/stt/listen", env.AI_SERVICE_URL);
+  aiServiceUrl.search = clientUrl.search;
+
+  const authHeader = reqHeaders.get("authorization");
+  const headers = authHeader ? { Authorization: authHeader } : undefined;
+
+  return new WsProxyConnection(aiServiceUrl.toString(), { headers });
+}
+
 export const listenSocketHandler: Handler<AppBindings> = async (c, next) => {
   const emit = c.get("emit");
   const userId = c.get("supabaseUserId");
@@ -18,7 +43,11 @@ export const listenSocketHandler: Handler<AppBindings> = async (c, next) => {
 
   let connection: WsProxyConnection;
   try {
-    connection = createProxyFromRequest(clientUrl, c.req.raw.headers);
+    if (env.AI_SERVICE_URL) {
+      connection = createAiServiceProxy(clientUrl, c.req.raw.headers);
+    } else {
+      connection = createProxyFromRequest(clientUrl, c.req.raw.headers);
+    }
     await connection.preconnectUpstream();
     emit({ type: "stt.websocket.connected", userId, provider });
   } catch (error) {
diff --git a/apps/api/src/middleware/ai-forward.ts b/apps/api/src/middleware/ai-forward.ts
new file mode 100644
index 0000000000..24a3ec1848
--- /dev/null
+++ b/apps/api/src/middleware/ai-forward.ts
@@ -0,0 +1,118 @@
+import type { Context, Next } from "hono";
+
+import { env } from "../env";
+import type { AppBindings } from "../hono-bindings";
+
+/** Upper bound on how long to wait for the AI service to start responding. */
+const REQUEST_TIMEOUT_MS = 120_000;
+
+/**
+ * Forwards the current HTTP request to `targetPath` on the AI service.
+ *
+ * Falls through to the next handler when `AI_SERVICE_URL` is not configured,
+ * and for WebSocket upgrade requests, which `fetch` cannot relay.
+ *
+ * The query string, `Content-Type` and `Authorization` headers, and the body
+ * are passed upstream; the upstream status, body, and content type are
+ * returned to the client.
+ *
+ * @param c - Hono request context.
+ * @param next - Next handler in the middleware chain.
+ * @param targetPath - Path on the AI service to forward to.
+ * @returns The proxied response, 504 on timeout, 499 if the client aborted,
+ *   or nothing when the request was passed on to `next`.
+ */
+export async function forwardToAiService(
+  c: Context<AppBindings>,
+  next: Next,
+  targetPath: string,
+): Promise<Response | void> {
+  if (!env.AI_SERVICE_URL) {
+    return next();
+  }
+
+  // A WebSocket handshake cannot be relayed through fetch(); leave it to the
+  // route's own WebSocket handler.
+  if (c.req.header("upgrade")?.toLowerCase() === "websocket") {
+    return next();
+  }
+
+  const targetUrl = new URL(targetPath, env.AI_SERVICE_URL);
+  const clientUrl = new URL(c.req.url);
+  targetUrl.search = clientUrl.search;
+
+  const authHeader = c.req.header("authorization");
+
+  const timeoutController = new AbortController();
+  const timeoutId = setTimeout(
+    () => timeoutController.abort(),
+    REQUEST_TIMEOUT_MS,
+  );
+  // Abort the upstream request on timeout or when the client goes away.
+  const signal = AbortSignal.any([c.req.raw.signal, timeoutController.signal]);
+
+  try {
+    const response = await fetch(targetUrl.toString(), {
+      method: c.req.method,
+      headers: {
+        "Content-Type": c.req.header("content-type") ?? "application/json",
+        ...(authHeader ? { Authorization: authHeader } : {}),
+      },
+      body: c.req.raw.body,
+      // @ts-expect-error - duplex is required for streaming request bodies
+      duplex: "half",
+      signal,
+    });
+
+    const contentType = response.headers.get("content-type") ?? "";
+    if (contentType.includes("text/event-stream")) {
+      return new Response(response.body, {
+        status: response.status,
+        headers: {
+          "Content-Type": "text/event-stream",
+          "Cache-Control": "no-cache",
+        },
+      });
+    }
+
+    return new Response(response.body, {
+      status: response.status,
+      headers: { "Content-Type": contentType || "application/json" },
+    });
+  } catch (error) {
+    if (signal.aborted) {
+      const isTimeout = timeoutController.signal.aborted;
+      return new Response(
+        isTimeout ? "Request timeout" : "Client disconnected",
+        { status: isTimeout ? 504 : 499 },
+      );
+    }
+    throw error;
+  } finally {
+    clearTimeout(timeoutId);
+  }
+}
+
+/** Middleware that forwards to `/llm/chat/completions` on the AI service. */
+export const forwardLlmToAi = async (
+  c: Context<AppBindings>,
+  next: Next,
+): Promise<Response | void> => {
+  return forwardToAiService(c, next, "/llm/chat/completions");
+};
+
+/** Middleware that forwards non-WebSocket requests to `/stt/listen`. */
+export const forwardSttListenToAi = async (
+  c: Context<AppBindings>,
+  next: Next,
+): Promise<Response | void> => {
+  return forwardToAiService(c, next, "/stt/listen");
+};
+
+/** Middleware that forwards to `/stt/` on the AI service. */
+export const forwardSttTranscribeToAi = async (
+  c: Context<AppBindings>,
+  next: Next,
+): Promise<Response | void> => {
+  return forwardToAiService(c, next, "/stt/");
+};
diff --git a/apps/api/src/middleware/index.ts b/apps/api/src/middleware/index.ts
index 39d6877f88..ee0ab42d87 100644
--- a/apps/api/src/middleware/index.ts
+++ b/apps/api/src/middleware/index.ts
@@ -1,3 +1,4 @@
+export * from "./ai-forward";
 export * from "./load-test-auth";
 export * from "./observability";
 export * from "./sentry";