diff --git a/.github/actions/sentry_cli/action.yaml b/.github/actions/sentry_cli/action.yaml
new file mode 100644
index 0000000000..b43d24b4e2
--- /dev/null
+++ b/.github/actions/sentry_cli/action.yaml
@@ -0,0 +1,11 @@
+inputs:
+  version:
+    required: false
+    default: "2.39.1"
+runs:
+  using: "composite"
+  steps:
+    - run: curl -sL https://sentry.io/get-cli/ | SENTRY_CLI_VERSION="${{ inputs.version }}" sh
+      shell: bash
+    - run: sentry-cli --version
+      shell: bash
diff --git a/.github/workflows/ai_cd.yaml b/.github/workflows/ai_cd.yaml
new file mode 100644
index 0000000000..aa7038423e
--- /dev/null
+++ b/.github/workflows/ai_cd.yaml
@@ -0,0 +1,93 @@
+on:
+  workflow_dispatch:
+
+jobs:
+  compute-version:
+    runs-on: ubuntu-latest
+    outputs:
+      version: ${{ steps.version.outputs.version }}
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          fetch-tags: true
+      - run: git fetch --tags --force
+      - uses: ./.github/actions/doxxer_install
+      - id: version
+        run: |
+          VERSION=$(doxxer --config doxxer.ai.toml next patch)
+          echo "version=$VERSION" >> $GITHUB_OUTPUT
+          echo "Computed version: $VERSION"
+
+  build:
+    needs: compute-version
+    runs-on: depot-ubuntu-24.04-8
+    timeout-minutes: 60
+    steps:
+      - uses: actions/checkout@v4
+      - uses: ./.github/actions/rust_install
+        with:
+          platform: linux
+      - run: |
+          cargo build --release -p ai
+          objcopy --only-keep-debug target/release/ai target/release/ai.debug
+          objcopy --strip-debug --strip-unneeded target/release/ai
+          objcopy --add-gnu-debuglink=target/release/ai.debug target/release/ai
+        env:
+          CARGO_PROFILE_RELEASE_DEBUG: "true"
+      - uses: actions/upload-artifact@v4
+        with:
+          name: ai-binary
+          path: target/release/ai
+          retention-days: 1
+      - uses: actions/upload-artifact@v4
+        with:
+          name: ai-debug
+          path: target/release/ai.debug
+          retention-days: 1
+
+  sentry-upload:
+    needs: [compute-version, build]
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/download-artifact@v4
+        with:
+          name: ai-debug
+          path: ./debug
+      - uses: ./.github/actions/sentry_cli
+      - run: sentry-cli debug-files upload --include-sources ./debug
+        env:
+          SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }}
+          SENTRY_ORG: ${{ secrets.SENTRY_ORG }}
+          SENTRY_PROJECT: ${{ secrets.SENTRY_PROJECT }}
+
+  deploy:
+    needs: [compute-version, build, sentry-upload]
+    runs-on: depot-ubuntu-24.04-8
+    timeout-minutes: 60
+    concurrency: ai-fly-deploy
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/download-artifact@v4
+        with:
+          name: ai-binary
+          path: ./apps/ai/bin
+      - run: chmod +x ./apps/ai/bin/ai
+      - uses: superfly/flyctl-actions/setup-flyctl@master
+      - run: flyctl deploy --config apps/ai/fly.toml --dockerfile apps/ai/Dockerfile --remote-only -e APP_VERSION=${{ needs.compute-version.outputs.version }}
+        env:
+          FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }}
+
+  tag:
+    needs: [compute-version, deploy]
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write
+    steps:
+      - uses: actions/checkout@v4
+      - uses: mathieudutour/github-tag-action@v6.2
+        with:
+          github_token: ${{ secrets.GITHUB_TOKEN }}
+          custom_tag: ai_v${{ needs.compute-version.outputs.version }}
+          tag_prefix: ""
diff --git a/.github/workflows/ai_ci.yaml b/.github/workflows/ai_ci.yaml
new file mode 100644
index 0000000000..86631f4b8c
--- /dev/null
+++ b/.github/workflows/ai_ci.yaml
@@ -0,0 +1,23 @@
+on:
+  workflow_dispatch:
+  push:
+    branches:
+      - main
+    paths:
+      - apps/ai/**
+      - crates/llm-proxy/**
+      - crates/transcribe-proxy/**
+  pull_request:
+    branches:
+      - main
+    paths:
+      - apps/ai/**
+      - crates/llm-proxy/**
+      - crates/transcribe-proxy/**
+jobs:
+  ci:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: ./.github/actions/rust_install
+      - run: cargo check -p ai
diff --git a/Cargo.lock b/Cargo.lock
index c681d2f79f..ec9f6234f8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -245,6 +245,28 @@ dependencies = [
  "memchr",
 ]

+[[package]]
+name = "ai"
+version = "0.1.0"
+dependencies = [
+ "axum 0.8.7",
+ "dotenvy",
+ "jsonwebtoken",
+ "llm-proxy",
+ "owhisper-providers",
+ "reqwest",
+ "sentry",
+ "serde",
+ "serde_json",
+ "tokio",
+ "tower 0.5.2",
+ "tower-http 0.6.8",
+ "tracing",
+ "tracing-subscriber",
+ "transcribe-proxy",
+ "url",
+]
+
 [[package]]
 name = "aide"
 version = "0.15.1"
@@ -9956,6 +9978,24 @@ dependencies = [
  "tokio",
 ]

+[[package]]
+name = "llm-proxy"
+version = "0.1.0"
+dependencies = [
+ "analytics",
+ "axum 0.8.7",
+ "bytes",
+ "futures-util",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.17",
+ "tokio",
+ "tokio-stream",
+ "tower 0.5.2",
+ "tracing",
+]
+
 [[package]]
 name = "local-waker"
 version = "0.1.4"
@@ -14786,6 +14826,7 @@ dependencies = [
  "sentry-core",
  "sentry-debug-images",
  "sentry-panic",
+ "sentry-tower",
  "sentry-tracing",
  "tokio",
  "ureq 3.1.4",
@@ -14873,6 +14914,21 @@ dependencies = [
  "thiserror 2.0.17",
 ]

+[[package]]
+name = "sentry-tower"
+version = "0.42.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4a303d0127d95ae928a937dcc0886931d28b4186e7338eea7d5786827b69b002"
+dependencies = [
+ "axum 0.8.7",
+ "http 1.4.0",
+ "pin-project",
+ "sentry-core",
+ "tower-layer",
+ "tower-service",
+ "url",
+]
+
 [[package]]
 name = "sentry-tracing"
 version = "0.42.0"
@@ -16031,26 +16087,6 @@ dependencies = [
  "syn 2.0.111",
 ]

-[[package]]
-name = "stt-server"
-version = "0.1.0"
-dependencies = [
- "axum 0.8.7",
- "dotenvy",
- "jsonwebtoken",
- "owhisper-providers",
- "reqwest",
- "sentry",
- "serde",
- "serde_json",
- "tokio",
- "tower-http 0.6.8",
- "tracing",
- "tracing-subscriber",
- "transcribe-proxy",
- "url",
-]
-
 [[package]]
 name = "subtle"
 version = "2.6.1"
@@ -18984,14 +19020,21 @@ dependencies = [
 name = "transcribe-proxy"
 version = "0.1.0"
 dependencies = [
+ "analytics",
  "axum 0.8.7",
  "bytes",
  "futures-util",
+ "owhisper-providers",
+ "reqwest",
+ "serde",
+ "serde_json",
  "thiserror 2.0.17",
  "tokio",
  "tokio-tungstenite 0.26.2",
  "tower 0.5.2",
  "tracing",
+ "url",
+ "uuid",
 ]

 [[package]]
diff --git a/Cargo.toml b/Cargo.toml
index 3843fbe3cd..5581de3dc6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,9 +10,9 @@ debug = false
 [workspace]
 resolver = "2"
 members = [
+  "apps/ai",
   "apps/desktop/src-tauri",
   "apps/granola",
-  "apps/stt",
   "crates/*",
   "owhisper/*",
   "plugins/*",
@@ -48,6 +48,7 @@ hypr-language = { path = "crates/language", package = "language" }
 hypr-llama = { path = "crates/llama", package = "llama" }
 hypr-llm = { path = "crates/llm", package = "llm" }
 hypr-llm-interface = { path = "crates/llm-interface", package = "llm-interface" }
+hypr-llm-proxy = { path = "crates/llm-proxy", package = "llm-proxy" }
 hypr-loops = { path = "crates/loops", package = "loops" }
 hypr-mac = { path = "crates/mac", package = "mac" }
 hypr-moonshine = { path = "crates/moonshine", package = "moonshine" }
diff --git a/apps/stt/Cargo.toml b/apps/ai/Cargo.toml
similarity index 79%
rename from apps/stt/Cargo.toml
rename to apps/ai/Cargo.toml
index 1e1c63ba8b..a035e2cfd6 100644
--- a/apps/stt/Cargo.toml
+++ b/apps/ai/Cargo.toml
@@ -1,13 +1,10 @@
 [package]
-name = "stt-server"
+name = "ai"
 version = "0.1.0"
 edition = "2024"

-[[bin]]
-name = "stt"
-path = "src/main.rs"
-
 [dependencies]
+hypr-llm-proxy = { workspace = true }
 hypr-transcribe-proxy = { workspace = true }
 owhisper-providers = { workspace = true }

@@ -24,4 +21,5 @@ url = { workspace = true }
 dotenvy = { workspace = true }
 jsonwebtoken = { workspace = true }

-sentry = { workspace = true }
+sentry = { workspace = true, features = ["tower", "tower-axum-matched-path", "tracing"] }
+tower = { workspace = true }
edition = "2024" -[[bin]] -name = "stt" -path = "src/main.rs" - [dependencies] +hypr-llm-proxy = { workspace = true } hypr-transcribe-proxy = { workspace = true } owhisper-providers = { workspace = true } @@ -24,4 +21,5 @@ url = { workspace = true } dotenvy = { workspace = true } jsonwebtoken = { workspace = true } -sentry = { workspace = true } +sentry = { workspace = true, features = ["tower", "tower-axum-matched-path", "tracing"] } +tower = { workspace = true } diff --git a/apps/ai/Dockerfile b/apps/ai/Dockerfile new file mode 100644 index 0000000000..8cab847cca --- /dev/null +++ b/apps/ai/Dockerfile @@ -0,0 +1,9 @@ +FROM debian:bookworm-slim + +RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/* + +WORKDIR /app +COPY bin/ai /app/ai + +EXPOSE 3000 +CMD ["/app/ai"] diff --git a/apps/ai/fly.toml b/apps/ai/fly.toml new file mode 100644 index 0000000000..c152036c8c --- /dev/null +++ b/apps/ai/fly.toml @@ -0,0 +1,42 @@ +app = 'hyprnote-ai' +primary_region = 'sjc' +kill_signal = 'SIGTERM' +kill_timeout = 30 +swap_size_mb = 512 + +[build] +dockerfile = "Dockerfile" + +[deploy] +strategy = "bluegreen" + +[env] +PORT = "3000" +SENTRY_ENVIRONMENT = "production" + +[http_service] +processes = ['app'] +internal_port = 3000 +force_https = true +auto_stop_machines = 'stop' +auto_start_machines = true +min_machines_running = 1 + +[http_service.concurrency] +type = "connections" +hard_limit = 200 +soft_limit = 150 + +[[http_service.checks]] +grace_period = "20s" +interval = "15s" +method = "GET" +path = "/llm/completions" +protocol = "http" +timeout = "4s" + +[[vm]] +processes = ['app'] +memory = '1gb' +cpu_kind = 'shared' +cpus = 1 diff --git a/apps/stt/src/auth.rs b/apps/ai/src/auth.rs similarity index 65% rename from apps/stt/src/auth.rs rename to apps/ai/src/auth.rs index ecd6e20952..93143c623c 100644 --- a/apps/stt/src/auth.rs +++ b/apps/ai/src/auth.rs @@ -33,8 +33,13 @@ struct Claims { entitlements: Vec, } +enum JwksState { + Available(JwkSet), + Empty, +} + struct CachedJwks { - jwks: JwkSet, + state: JwksState, fetched_at: Instant, } @@ -45,14 +50,17 @@ fn jwks_cache() -> &'static Arc>> { JWKS_CACHE.get_or_init(|| Arc::new(RwLock::new(None))) } -async fn get_jwks() -> Result { +async fn get_jwks() -> Result { let cache = jwks_cache(); { let guard = cache.read().await; if let Some(cached) = guard.as_ref() { if cached.fetched_at.elapsed() < JWKS_CACHE_TTL { - return Ok(cached.jwks.clone()); + return Ok(match &cached.state { + JwksState::Available(jwks) => JwksState::Available(jwks.clone()), + JwksState::Empty => JwksState::Empty, + }); } } } @@ -67,15 +75,24 @@ async fn get_jwks() -> Result { .await .map_err(|_| "failed to parse jwks")?; + let state = if jwks.keys.is_empty() { + JwksState::Empty + } else { + JwksState::Available(jwks) + }; + { let mut guard = cache.write().await; *guard = Some(CachedJwks { - jwks: jwks.clone(), + state: match &state { + JwksState::Available(jwks) => JwksState::Available(jwks.clone()), + JwksState::Empty => JwksState::Empty, + }, fetched_at: Instant::now(), }); } - Ok(jwks) + Ok(state) } impl FromRequestParts for AuthUser @@ -96,6 +113,33 @@ where .or_else(|| auth_header.strip_prefix("bearer ")) .ok_or((StatusCode::UNAUTHORIZED, "invalid authorization header"))?; + let jwks_state = get_jwks() + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; + + let jwks = match jwks_state { + JwksState::Empty => { + if cfg!(debug_assertions) { + tracing::warn!( + target: "security", + "JWKS empty in debug build: accepting 
unsigned token" + ); + let claims = decode_claims_insecure(token) + .map_err(|_| (StatusCode::UNAUTHORIZED, "invalid token"))?; + return Ok(AuthUser { + user_id: claims.sub, + entitlements: claims.entitlements, + }); + } + tracing::error!( + target: "security", + "JWKS empty in release build: rejecting request" + ); + return Err((StatusCode::UNAUTHORIZED, "authentication unavailable")); + } + JwksState::Available(jwks) => jwks, + }; + let header = decode_header(token).map_err(|_| (StatusCode::UNAUTHORIZED, "invalid token header"))?; @@ -104,10 +148,6 @@ where .as_ref() .ok_or((StatusCode::UNAUTHORIZED, "missing kid in token"))?; - let jwks = get_jwks() - .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; - let jwk = jwks .find(kid) .ok_or((StatusCode::UNAUTHORIZED, "unknown signing key"))?; @@ -134,3 +174,9 @@ where }) } } + +fn decode_claims_insecure(token: &str) -> Result { + jsonwebtoken::dangerous::insecure_decode::(token) + .map(|data| data.claims) + .map_err(|_| ()) +} diff --git a/apps/stt/src/env.rs b/apps/ai/src/env.rs similarity index 82% rename from apps/stt/src/env.rs rename to apps/ai/src/env.rs index 4c376c0d3d..5ded8f9050 100644 --- a/apps/stt/src/env.rs +++ b/apps/ai/src/env.rs @@ -6,7 +6,9 @@ use owhisper_providers::Provider; pub struct Env { pub port: u16, pub sentry_dsn: Option, + pub sentry_environment: Option, pub supabase_url: String, + pub openrouter_api_key: String, api_keys: HashMap, } @@ -37,16 +39,15 @@ impl Env { Self { port: parse_or("PORT", 3000), sentry_dsn: optional("SENTRY_DSN"), + sentry_environment: optional("SENTRY_ENVIRONMENT"), supabase_url: required("SUPABASE_URL"), + openrouter_api_key: required("OPENROUTER_API_KEY"), api_keys, } } - pub fn api_key_for(&self, provider: Provider) -> String { - self.api_keys - .get(&provider) - .cloned() - .unwrap_or_else(|| panic!("{} is not configured", provider.env_key_name())) + pub fn api_keys(&self) -> HashMap { + self.api_keys.clone() } } diff --git a/apps/ai/src/main.rs b/apps/ai/src/main.rs new file mode 100644 index 0000000000..6c18aec480 --- /dev/null +++ b/apps/ai/src/main.rs @@ -0,0 +1,76 @@ +mod auth; +mod env; + +use std::net::SocketAddr; +use std::time::Duration; + +use axum::{Router, body::Body, http::Request}; +use sentry::integrations::tower::{NewSentryLayer, SentryHttpLayer}; +use tower::ServiceBuilder; +use tower_http::trace::TraceLayer; + +use env::env; + +fn app() -> Router { + let llm_config = hypr_llm_proxy::LlmProxyConfig::new(&env().openrouter_api_key); + let stt_config = hypr_transcribe_proxy::SttProxyConfig::new(env().api_keys()); + + Router::new() + .nest("/stt", hypr_transcribe_proxy::router(stt_config)) + .nest("/llm", hypr_llm_proxy::router(llm_config)) + .layer( + ServiceBuilder::new() + .layer(NewSentryLayer::>::new_from_top()) + .layer(SentryHttpLayer::new().enable_transaction()) + .layer(TraceLayer::new_for_http()), + ) +} + +fn main() -> std::io::Result<()> { + let env = env(); + + let _guard = sentry::init(sentry::ClientOptions { + dsn: env.sentry_dsn.as_ref().and_then(|s| s.parse().ok()), + release: sentry::release_name!(), + environment: env.sentry_environment.clone().map(Into::into), + traces_sample_rate: 1.0, + send_default_pii: true, + auto_session_tracking: true, + session_mode: sentry::SessionMode::Request, + ..Default::default() + }); + + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "info,tower_http=debug".into()), + ) + .init(); + + tokio::runtime::Builder::new_multi_thread() + 
diff --git a/apps/ai/src/main.rs b/apps/ai/src/main.rs
new file mode 100644
index 0000000000..6c18aec480
--- /dev/null
+++ b/apps/ai/src/main.rs
@@ -0,0 +1,76 @@
+mod auth;
+mod env;
+
+use std::net::SocketAddr;
+use std::time::Duration;
+
+use axum::{Router, body::Body, http::Request};
+use sentry::integrations::tower::{NewSentryLayer, SentryHttpLayer};
+use tower::ServiceBuilder;
+use tower_http::trace::TraceLayer;
+
+use env::env;
+
+fn app() -> Router {
+    let llm_config = hypr_llm_proxy::LlmProxyConfig::new(&env().openrouter_api_key);
+    let stt_config = hypr_transcribe_proxy::SttProxyConfig::new(env().api_keys());
+
+    Router::new()
+        .nest("/stt", hypr_transcribe_proxy::router(stt_config))
+        .nest("/llm", hypr_llm_proxy::router(llm_config))
+        .layer(
+            ServiceBuilder::new()
+                .layer(NewSentryLayer::<Request<Body>>::new_from_top())
+                .layer(SentryHttpLayer::new().enable_transaction())
+                .layer(TraceLayer::new_for_http()),
+        )
+}
+
+fn main() -> std::io::Result<()> {
+    let env = env();
+
+    let _guard = sentry::init(sentry::ClientOptions {
+        dsn: env.sentry_dsn.as_ref().and_then(|s| s.parse().ok()),
+        release: sentry::release_name!(),
+        environment: env.sentry_environment.clone().map(Into::into),
+        traces_sample_rate: 1.0,
+        send_default_pii: true,
+        auto_session_tracking: true,
+        session_mode: sentry::SessionMode::Request,
+        ..Default::default()
+    });
+
+    tracing_subscriber::fmt()
+        .with_env_filter(
+            tracing_subscriber::EnvFilter::try_from_default_env()
+                .unwrap_or_else(|_| "info,tower_http=debug".into()),
+        )
+        .init();
+
+    tokio::runtime::Builder::new_multi_thread()
+        .enable_all()
+        .build()?
+        .block_on(async {
+            let addr = SocketAddr::from(([0, 0, 0, 0], env.port));
+            tracing::info!("listening on {}", addr);
+
+            let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
+            axum::serve(listener, app())
+                .with_graceful_shutdown(shutdown_signal())
+                .await
+                .unwrap();
+        });
+
+    if let Some(client) = sentry::Hub::current().client() {
+        client.close(Some(Duration::from_secs(2)));
+    }
+
+    Ok(())
+}
+
+async fn shutdown_signal() {
+    tokio::signal::ctrl_c()
+        .await
+        .expect("failed to install CTRL+C signal handler");
+    tracing::info!("shutting down");
+}
diff --git a/apps/stt/src/main.rs b/apps/stt/src/main.rs
deleted file mode 100644
index 5e5dc1353e..0000000000
--- a/apps/stt/src/main.rs
+++ /dev/null
@@ -1,50 +0,0 @@
-mod auth;
-mod env;
-mod handlers;
-
-use std::net::SocketAddr;
-
-use axum::{Router, routing::any};
-use tower_http::trace::TraceLayer;
-
-use env::env;
-use handlers::ws_handler;
-
-fn app() -> Router {
-    Router::new()
-        .route("/listen", any(ws_handler))
-        .layer(TraceLayer::new_for_http())
-}
-
-#[tokio::main]
-async fn main() {
-    tracing_subscriber::fmt()
-        .with_env_filter(
-            tracing_subscriber::EnvFilter::try_from_default_env()
-                .unwrap_or_else(|_| "info,tower_http=debug".into()),
-        )
-        .init();
-
-    let env = env();
-
-    let _guard = sentry::init(sentry::ClientOptions {
-        dsn: env.sentry_dsn.as_ref().and_then(|s| s.parse().ok()),
-        ..Default::default()
-    });
-
-    let addr = SocketAddr::from(([0, 0, 0, 0], env.port));
-    tracing::info!("listening on {}", addr);
-
-    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
-    axum::serve(listener, app())
-        .with_graceful_shutdown(shutdown_signal())
-        .await
-        .unwrap();
-}
-
-async fn shutdown_signal() {
-    tokio::signal::ctrl_c()
-        .await
-        .expect("failed to install CTRL+C signal handler");
-    tracing::info!("shutting down");
-}
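Review note: a quick local smoke test against the composed router, as a hedged sketch — the host and prompt are illustrative, and no auth middleware is visible on these routes in this diff (production wiring may differ):

```rust
// Hypothetical caller, not part of the patch. `/llm/completions` comes from
// nesting the llm-proxy router under `/llm` in app() above.
async fn smoke_test() -> reqwest::Result<()> {
    let body = serde_json::json!({
        "messages": [{ "role": "user", "content": "ping" }],
        "max_tokens": 8
    });
    let resp = reqwest::Client::new()
        .post("http://localhost:3000/llm/completions")
        .json(&body)
        .send()
        .await?;
    println!("status = {}", resp.status());
    Ok(())
}
```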
diff --git a/crates/llm-proxy/Cargo.toml b/crates/llm-proxy/Cargo.toml
new file mode 100644
index 0000000000..5cf8c3c183
--- /dev/null
+++ b/crates/llm-proxy/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "llm-proxy"
+version = "0.1.0"
+edition = "2024"
+
+[dependencies]
+hypr-analytics = { workspace = true }
+
+axum = { workspace = true }
+futures-util = { workspace = true }
+reqwest = { workspace = true, features = ["json", "stream"] }
+tokio = { workspace = true, features = ["rt-multi-thread", "time", "sync", "macros"] }
+tokio-stream = { workspace = true }
+tracing = { workspace = true }
+
+bytes = { workspace = true }
+serde = { workspace = true, features = ["derive"] }
+serde_json = { workspace = true }
+thiserror = { workspace = true }
+
+[dev-dependencies]
+tower = { workspace = true, features = ["util"] }
diff --git a/crates/llm-proxy/src/analytics.rs b/crates/llm-proxy/src/analytics.rs
new file mode 100644
index 0000000000..840d31ff64
--- /dev/null
+++ b/crates/llm-proxy/src/analytics.rs
@@ -0,0 +1,113 @@
+use std::sync::Arc;
+
+use hypr_analytics::{AnalyticsClient, AnalyticsPayload};
+use reqwest::Client;
+use serde::Deserialize;
+
+use crate::types::OPENROUTER_URL;
+
+#[derive(Debug, Clone)]
+pub struct GenerationEvent {
+    pub generation_id: String,
+    pub model: String,
+    pub input_tokens: u32,
+    pub output_tokens: u32,
+    pub latency: f64,
+    pub http_status: u16,
+    pub total_cost: Option<f64>,
+}
+
+pub trait AnalyticsReporter: Send + Sync {
+    fn report_generation(
+        &self,
+        event: GenerationEvent,
+    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>>;
+}
+
+impl AnalyticsReporter for AnalyticsClient {
+    fn report_generation(
+        &self,
+        event: GenerationEvent,
+    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>> {
+        Box::pin(async move {
+            let payload = AnalyticsPayload::builder("$ai_generation")
+                .with("$ai_provider", "openrouter")
+                .with("$ai_model", event.model.clone())
+                .with("$ai_input_tokens", event.input_tokens)
+                .with("$ai_output_tokens", event.output_tokens)
+                .with("$ai_latency", event.latency)
+                .with("$ai_trace_id", event.generation_id.clone())
+                .with("$ai_http_status", event.http_status)
+                .with("$ai_base_url", OPENROUTER_URL);
+
+            let payload = if let Some(cost) = event.total_cost {
+                payload.with("$ai_total_cost_usd", cost)
+            } else {
+                payload
+            };
+
+            let _ = self.event(event.generation_id, payload.build()).await;
+        })
+    }
+}
+
+pub async fn send_generation_event(
+    analytics: &Arc<dyn AnalyticsReporter>,
+    generation_id: String,
+    model: String,
+    input_tokens: u32,
+    output_tokens: u32,
+    latency: f64,
+    http_status: u16,
+    total_cost: Option<f64>,
+) {
+    let event = GenerationEvent {
+        generation_id,
+        model,
+        input_tokens,
+        output_tokens,
+        latency,
+        http_status,
+        total_cost,
+    };
+    analytics.report_generation(event).await;
+}
+
+pub async fn fetch_generation_metadata(
+    client: &Client,
+    api_key: &str,
+    generation_id: &str,
+) -> Option<f64> {
+    #[derive(Deserialize)]
+    struct OpenRouterGenerationResponse {
+        data: OpenRouterGenerationData,
+    }
+
+    #[derive(Deserialize)]
+    struct OpenRouterGenerationData {
+        total_cost: f64,
+    }
+
+    let url = format!(
+        "https://openrouter.ai/api/v1/generation?id={}",
+        generation_id
+    );
+
+    let response = client
+        .get(&url)
+        .header("Authorization", format!("Bearer {}", api_key))
+        .send()
+        .await
+        .ok()?;
+
+    if !response.status().is_success() {
+        tracing::warn!(
+            status = %response.status(),
+            "failed to fetch generation metadata"
+        );
+        return None;
+    }
+
+    let data: OpenRouterGenerationResponse = response.json().await.ok()?;
+    Some(data.data.total_cost)
+}
diff --git a/crates/llm-proxy/src/config.rs b/crates/llm-proxy/src/config.rs
new file mode 100644
index 0000000000..855c1b07c6
--- /dev/null
+++ b/crates/llm-proxy/src/config.rs
@@ -0,0 +1,54 @@
+use std::sync::Arc;
+use std::time::Duration;
+
+use crate::analytics::AnalyticsReporter;
+
+const DEFAULT_TIMEOUT_MS: u64 = 120_000;
+
+#[derive(Clone)]
+pub struct LlmProxyConfig {
+    pub api_key: String,
+    pub timeout: Duration,
+    pub models_tool_calling: Vec<String>,
+    pub models_default: Vec<String>,
+    pub analytics: Option<Arc<dyn AnalyticsReporter>>,
+}
+
+impl LlmProxyConfig {
+    pub fn new(api_key: impl Into<String>) -> Self {
+        Self {
+            api_key: api_key.into(),
+            timeout: Duration::from_millis(DEFAULT_TIMEOUT_MS),
+            models_tool_calling: vec![
+                "moonshotai/kimi-k2-0905:exacto".into(),
+                "anthropic/claude-haiku-4.5".into(),
+                "openai/gpt-oss-120b:exacto".into(),
+            ],
+            models_default: vec![
+                "moonshotai/kimi-k2-0905".into(),
+                "openai/gpt-5.1-chat".into(),
+            ],
+            analytics: None,
+        }
+    }
+
+    pub fn with_timeout(mut self, timeout: Duration) -> Self {
+        self.timeout = timeout;
+        self
+    }
+
+    pub fn with_models_tool_calling(mut self, models: Vec<String>) -> Self {
+        self.models_tool_calling = models;
+        self
+    }
+
+    pub fn with_models_default(mut self, models: Vec<String>) -> Self {
+        self.models_default = models;
+        self
+    }
+
+    pub fn with_analytics(mut self, reporter: Arc<dyn AnalyticsReporter>) -> Self {
+        self.analytics = Some(reporter);
+        self
+    }
+}
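Review note: how a caller might assemble this config — a sketch under the builder API shown above; the timeout value and model ID are illustrative, and it assumes the caller already holds an `AnalyticsClient` (the hypr-analytics type that implements `AnalyticsReporter`):

```rust
use std::sync::Arc;
use std::time::Duration;

use hypr_analytics::AnalyticsClient;

// Illustrative wiring: every setter returns Self, so the whole config can
// be built in one chain. Arc<AnalyticsClient> coerces to the trait object.
fn example_config(api_key: &str, analytics: Arc<AnalyticsClient>) -> LlmProxyConfig {
    LlmProxyConfig::new(api_key)
        .with_timeout(Duration::from_secs(60))
        .with_models_default(vec!["moonshotai/kimi-k2-0905".into()])
        .with_analytics(analytics)
}
```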
diff --git a/crates/llm-proxy/src/error.rs b/crates/llm-proxy/src/error.rs
new file mode 100644
index 0000000000..654fc05660
--- /dev/null
+++ b/crates/llm-proxy/src/error.rs
@@ -0,0 +1,11 @@
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+    #[error("request failed: {0}")]
+    Request(#[from] reqwest::Error),
+    #[error("upstream error: {status} - {body}")]
+    Upstream { status: u16, body: String },
+    #[error("timeout")]
+    Timeout,
+    #[error("client disconnected")]
+    ClientDisconnected,
+}
diff --git a/crates/llm-proxy/src/handler.rs b/crates/llm-proxy/src/handler.rs
new file mode 100644
index 0000000000..b426d74947
--- /dev/null
+++ b/crates/llm-proxy/src/handler.rs
@@ -0,0 +1,280 @@
+use std::sync::Arc;
+use std::time::Instant;
+
+use axum::{
+    Json, Router,
+    body::Body,
+    extract::State,
+    http::StatusCode,
+    response::{IntoResponse, Response},
+    routing::post,
+};
+use bytes::Bytes;
+use futures_util::StreamExt;
+use reqwest::Client;
+
+use crate::analytics::{AnalyticsReporter, fetch_generation_metadata, send_generation_event};
+use crate::config::LlmProxyConfig;
+use crate::types::{
+    ChatCompletionRequest, OPENROUTER_URL, OpenRouterRequest, OpenRouterResponse, Provider,
+    ToolChoice,
+};
+
+#[derive(Clone)]
+struct AppState {
+    config: LlmProxyConfig,
+    client: Client,
+}
+
+pub fn router(config: LlmProxyConfig) -> Router {
+    let state = AppState {
+        config,
+        client: Client::new(),
+    };
+
+    Router::new()
+        .route("/completions", post(completions_handler))
+        .with_state(state)
+}
+
+#[derive(Default)]
+struct StreamMetadata {
+    generation_id: Option<String>,
+    model: Option<String>,
+    input_tokens: u32,
+    output_tokens: u32,
+}
+
+fn extract_stream_metadata(chunk: &[u8], metadata: &mut StreamMetadata) {
+    let Ok(text) = std::str::from_utf8(chunk) else {
+        return;
+    };
+
+    for line in text.lines() {
+        let Some(data) = line.strip_prefix("data: ") else {
+            continue;
+        };
+
+        if data.trim() == "[DONE]" {
+            continue;
+        }
+
+        let Ok(parsed) = serde_json::from_str::<serde_json::Value>(data) else {
+            continue;
+        };
+
+        if metadata.generation_id.is_none() {
+            metadata.generation_id = parsed.get("id").and_then(|v| v.as_str()).map(String::from);
+        }
+
+        if metadata.model.is_none() {
+            metadata.model = parsed
+                .get("model")
+                .and_then(|v| v.as_str())
+                .map(String::from);
+        }
+
+        if let Some(usage) = parsed.get("usage") {
+            if let Some(pt) = usage.get("prompt_tokens").and_then(|v| v.as_u64()) {
+                metadata.input_tokens = pt as u32;
+            }
+            if let Some(ct) = usage.get("completion_tokens").and_then(|v| v.as_u64()) {
+                metadata.output_tokens = ct as u32;
+            }
+        }
+    }
+}
+
+async fn completions_handler(
+    State(state): State<AppState>,
+    Json(request): Json<ChatCompletionRequest>,
+) -> Response {
+    let start_time = Instant::now();
+
+    let needs_tool_calling = request.tools.as_ref().is_some_and(|t| !t.is_empty())
+        && !matches!(&request.tool_choice, Some(ToolChoice::String(s)) if s == "none");
+
+    let models = if needs_tool_calling {
+        state.config.models_tool_calling.clone()
+    } else {
+        state.config.models_default.clone()
+    };
+
+    let stream = request.stream.unwrap_or(false);
+
+    let openrouter_request = OpenRouterRequest {
+        messages: request.messages,
+        tools: request.tools,
+        tool_choice: request.tool_choice,
+        temperature: request.temperature,
+        max_tokens: request.max_tokens,
+        stream,
+        models,
+        provider: Provider { sort: "latency" },
+        extra: request.extra,
+    };
+
+    let result = tokio::time::timeout(state.config.timeout, async {
+        state
+            .client
+            .post(OPENROUTER_URL)
+            .header("Content-Type", "application/json")
+            .header("Authorization", format!("Bearer {}", state.config.api_key))
+            .json(&openrouter_request)
+            .send()
+            .await
+    })
+    .await;
+
+    let response = match result {
+        Ok(Ok(resp)) => resp,
+        Ok(Err(e)) => {
+            tracing::error!(error = %e, "upstream request failed");
+            return (StatusCode::BAD_GATEWAY, e.to_string()).into_response();
+        }
+        Err(_) => {
+            tracing::error!("upstream request timeout");
+            return (StatusCode::GATEWAY_TIMEOUT, "Request timeout").into_response();
+        }
+    };
+
+    let status = response.status();
+    let http_status = status.as_u16();
+
+    if stream {
+        handle_stream_response(state, response, start_time, http_status).await
+    } else {
+        handle_non_stream_response(state, response, start_time, http_status).await
+    }
+}
+
+async fn handle_stream_response(
+    state: AppState,
+    response: reqwest::Response,
+    start_time: Instant,
+    http_status: u16,
+) -> Response {
+    let status = response.status();
+    let analytics: Option<Arc<dyn AnalyticsReporter>> = state.config.analytics.clone();
+    let api_key = state.config.api_key.clone();
+    let client = state.client.clone();
+
+    let stream = response.bytes_stream();
+    let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, std::io::Error>>(32);
+
+    tokio::spawn(async move {
+        let mut metadata = StreamMetadata::default();
+
+        futures_util::pin_mut!(stream);
+
+        while let Some(chunk_result) = stream.next().await {
+            match chunk_result {
+                Ok(chunk) => {
+                    if analytics.is_some() {
+                        extract_stream_metadata(&chunk, &mut metadata);
+                    }
+
+                    if tx.send(Ok(chunk)).await.is_err() {
+                        break;
+                    }
+                }
+                Err(e) => {
+                    let _ = tx
+                        .send(Err(std::io::Error::new(std::io::ErrorKind::Other, e)))
+                        .await;
+                    break;
+                }
+            }
+        }
+
+        let latency = start_time.elapsed().as_secs_f64();
+
+        if let Some(analytics) = analytics {
+            if let Some(gen_id) = metadata.generation_id {
+                let total_cost = fetch_generation_metadata(&client, &api_key, &gen_id).await;
+
+                send_generation_event(
+                    &analytics,
+                    gen_id,
+                    metadata.model.unwrap_or_default(),
+                    metadata.input_tokens,
+                    metadata.output_tokens,
+                    latency,
+                    http_status,
+                    total_cost,
+                )
+                .await;
+            }
+        }
+    });
+
+    let body = Body::from_stream(tokio_stream::wrappers::ReceiverStream::new(rx));
+    Response::builder()
+        .status(status)
+        .header("Content-Type", "text/event-stream")
+        .header("Cache-Control", "no-cache")
+        .body(body)
+        .unwrap()
+}
+
+async fn handle_non_stream_response(
+    state: AppState,
+    response: reqwest::Response,
+    start_time: Instant,
+    http_status: u16,
+) -> Response {
+    let status = response.status();
+
+    let body_bytes = match response.bytes().await {
+        Ok(b) => b,
+        Err(e) => {
+            tracing::error!(error = %e, "failed to read response body");
+            return (StatusCode::BAD_GATEWAY, "Failed to read response").into_response();
+        }
+    };
+
+    let latency = start_time.elapsed().as_secs_f64();
+
+    if let Some(analytics) = &state.config.analytics {
+        if let Ok(parsed) = serde_json::from_slice::<OpenRouterResponse>(&body_bytes) {
+            let client = state.client.clone();
+            let api_key = state.config.api_key.clone();
+            let analytics = analytics.clone();
+            let generation_id = parsed.id.clone();
+
+            let input_tokens = parsed
+                .usage
+                .as_ref()
+                .and_then(|u| u.prompt_tokens)
+                .unwrap_or(0);
+            let output_tokens = parsed
+                .usage
+                .as_ref()
+                .and_then(|u| u.completion_tokens)
+                .unwrap_or(0);
+            let model = parsed.model.clone().unwrap_or_default();
+
+            tokio::spawn(async move {
+                let total_cost = fetch_generation_metadata(&client, &api_key, &generation_id).await;
+
+                send_generation_event(
+                    &analytics,
+                    generation_id,
+                    model,
+                    input_tokens,
+                    output_tokens,
+                    latency,
+                    http_status,
+                    total_cost,
+                )
+                .await;
+            });
+        }
+    }
+
+    Response::builder()
+        .status(status)
+        .header("Content-Type", "application/json")
+        .body(Body::from(body_bytes))
+        .unwrap()
+}
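Review note: a hypothetical unit test (not in the patch) pinning down the SSE frame shape `extract_stream_metadata` understands; it would have to live in this module since the function is private:

```rust
#[cfg(test)]
mod stream_metadata_tests {
    use super::*;

    #[test]
    fn picks_up_id_model_and_usage() {
        // Two SSE frames: one data chunk with usage, then the terminator.
        let chunk = "data: {\"id\":\"gen-123\",\"model\":\"test/model\",\"usage\":{\"prompt_tokens\":3,\"completion_tokens\":5}}\n\ndata: [DONE]\n\n";

        let mut meta = StreamMetadata::default();
        extract_stream_metadata(chunk.as_bytes(), &mut meta);

        assert_eq!(meta.generation_id.as_deref(), Some("gen-123"));
        assert_eq!(meta.model.as_deref(), Some("test/model"));
        assert_eq!(meta.input_tokens, 3);
        assert_eq!(meta.output_tokens, 5);
    }
}
```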
diff --git a/crates/llm-proxy/src/lib.rs b/crates/llm-proxy/src/lib.rs
new file mode 100644
index 0000000000..42beb5fffd
--- /dev/null
+++ b/crates/llm-proxy/src/lib.rs
@@ -0,0 +1,144 @@
+mod analytics;
+mod config;
+mod error;
+mod handler;
+mod types;
+
+pub use analytics::{AnalyticsReporter, GenerationEvent};
+pub use config::*;
+pub use error::*;
+pub use handler::router;
+
+#[cfg(test)]
+mod tests {
+    use std::sync::{Arc, Mutex};
+
+    use axum::body::Body;
+    use axum::http::{Request, StatusCode};
+    use tower::ServiceExt;
+
+    use super::*;
+
+    #[derive(Default, Clone)]
+    struct MockAnalytics {
+        events: Arc<Mutex<Vec<GenerationEvent>>>,
+    }
+
+    impl AnalyticsReporter for MockAnalytics {
+        fn report_generation(
+            &self,
+            event: GenerationEvent,
+        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>> {
+            let events = self.events.clone();
+            Box::pin(async move {
+                events.lock().unwrap().push(event);
+            })
+        }
+    }
+
+    #[ignore]
+    #[tokio::test]
+    async fn e2e_completions_with_mock_analytics() {
+        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set");
+
+        let mock_analytics = MockAnalytics::default();
+        let events = mock_analytics.events.clone();
+
+        let config = LlmProxyConfig::new(api_key)
+            .with_models_default(vec!["openai/gpt-4.1-nano".into()])
+            .with_analytics(Arc::new(mock_analytics));
+
+        let app = router(config);
+
+        let request_body = serde_json::json!({
+            "messages": [
+                {"role": "user", "content": "Say 'hello' and nothing else."}
+            ],
+            "max_tokens": 10
+        });
+
+        let request = Request::builder()
+            .method("POST")
+            .uri("/completions")
+            .header("Content-Type", "application/json")
+            .body(Body::from(serde_json::to_string(&request_body).unwrap()))
+            .unwrap();
+
+        let response = app.oneshot(request).await.unwrap();
+
+        assert_eq!(response.status(), StatusCode::OK);
+
+        let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
+            .await
+            .unwrap();
+        let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
+
+        assert!(body.get("id").is_some());
+        assert!(body.get("choices").is_some());
+
+        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
+
+        let captured_events = events.lock().unwrap();
+        assert_eq!(captured_events.len(), 1);
+
+        let event = &captured_events[0];
+        assert!(!event.generation_id.is_empty());
+        assert!(!event.model.is_empty());
+        assert_eq!(event.http_status, 200);
+        assert!(event.input_tokens > 0);
+        assert!(event.output_tokens > 0);
+        assert!(event.latency > 0.0);
+    }
+
+    #[ignore]
+    #[tokio::test]
+    async fn e2e_completions_stream_with_mock_analytics() {
+        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set");
+
+        let mock_analytics = MockAnalytics::default();
+        let events = mock_analytics.events.clone();
+
+        let config = LlmProxyConfig::new(api_key)
+            .with_models_default(vec!["openai/gpt-4.1-nano".into()])
+            .with_analytics(Arc::new(mock_analytics));
+
+        let app = router(config);
+
+        let request_body = serde_json::json!({
+            "messages": [
+                {"role": "user", "content": "Say 'hello' and nothing else."}
+            ],
+            "stream": true,
+            "max_tokens": 10
+        });
+
+        let request = Request::builder()
+            .method("POST")
+            .uri("/completions")
+            .header("Content-Type", "application/json")
+            .body(Body::from(serde_json::to_string(&request_body).unwrap()))
+            .unwrap();
+
+        let response = app.oneshot(request).await.unwrap();
+
+        assert_eq!(response.status(), StatusCode::OK);
+
+        let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
+            .await
+            .unwrap();
+        let body_str = String::from_utf8_lossy(&body_bytes);
+
+        assert!(body_str.contains("data: "));
+
+        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
+
+        let captured_events = events.lock().unwrap();
+        assert_eq!(captured_events.len(), 1);
+
+        let event = &captured_events[0];
+        assert!(!event.generation_id.is_empty());
+        assert!(!event.model.is_empty());
+        assert_eq!(event.http_status, 200);
+        assert!(event.latency > 0.0);
+    }
+}
diff --git a/crates/llm-proxy/src/types.rs b/crates/llm-proxy/src/types.rs
new file mode 100644
index 0000000000..10dc5e7309
--- /dev/null
+++ b/crates/llm-proxy/src/types.rs
@@ -0,0 +1,107 @@
+use serde::{Deserialize, Serialize};
+
+pub const OPENROUTER_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
+
+#[derive(Debug, Deserialize)]
+pub struct ChatMessage {
+    pub role: String,
+    pub content: String,
+}
+
+impl Serialize for ChatMessage {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeMap;
+        let mut map = serializer.serialize_map(Some(2))?;
+        map.serialize_entry("role", &self.role)?;
+        map.serialize_entry("content", &self.content)?;
+        map.end()
+    }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+pub enum ToolChoice {
+    String(String),
+    Object {
+        #[serde(rename = "type")]
+        type_: String,
+        function: serde_json::Value,
+    },
+}
+
+impl Serialize for ToolChoice {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        match self {
+            ToolChoice::String(s) => serializer.serialize_str(s),
+            ToolChoice::Object { type_, function } => {
+                use serde::ser::SerializeMap;
+                let mut map = serializer.serialize_map(Some(2))?;
+                map.serialize_entry("type", type_)?;
+                map.serialize_entry("function", function)?;
+                map.end()
+            }
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ChatCompletionRequest {
+    #[serde(default)]
+    #[allow(dead_code)]
+    pub model: Option<String>,
+    pub messages: Vec<ChatMessage>,
+    #[serde(default)]
+    pub tools: Option<Vec<serde_json::Value>>,
+    #[serde(default)]
+    pub tool_choice: Option<ToolChoice>,
+    #[serde(default)]
+    pub stream: Option<bool>,
+    #[serde(default)]
+    pub temperature: Option<f64>,
+    #[serde(default)]
+    pub max_tokens: Option<u32>,
+    #[serde(flatten)]
+    pub extra: serde_json::Map<String, serde_json::Value>,
+}
+
+#[derive(Serialize)]
+pub struct OpenRouterRequest {
+    pub messages: Vec<ChatMessage>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<serde_json::Value>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_tokens: Option<u32>,
+    pub stream: bool,
+    pub models: Vec<String>,
+    pub provider: Provider,
+    #[serde(flatten)]
+    pub extra: serde_json::Map<String, serde_json::Value>,
+}
+
+#[derive(Serialize)]
+pub struct Provider {
+    pub sort: &'static str,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct OpenRouterResponse {
+    pub id: String,
+    pub model: Option<String>,
+    pub usage: Option<UsageInfo>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct UsageInfo {
+    pub prompt_tokens: Option<u32>,
+    pub completion_tokens: Option<u32>,
+}
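Review note: a sanity sketch (hypothetical test, not in the patch) showing that the untagged `ToolChoice` accepts both wire shapes and that the manual `Serialize` impl writes them back unchanged:

```rust
#[test]
fn tool_choice_round_trips() {
    // Plain string form, e.g. "auto" / "none".
    let s: ToolChoice = serde_json::from_str("\"auto\"").unwrap();
    assert_eq!(serde_json::to_string(&s).unwrap(), "\"auto\"");

    // Object form with a named function.
    let raw = r#"{"type":"function","function":{"name":"f"}}"#;
    let o: ToolChoice = serde_json::from_str(raw).unwrap();
    assert_eq!(serde_json::to_string(&o).unwrap(), raw);
}
```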
diff --git a/crates/transcribe-proxy/Cargo.toml b/crates/transcribe-proxy/Cargo.toml
index e452835d83..c4f75dfdef 100644
--- a/crates/transcribe-proxy/Cargo.toml
+++ b/crates/transcribe-proxy/Cargo.toml
@@ -4,12 +4,20 @@ version = "0.1.0"
 edition = "2024"

 [dependencies]
-thiserror = { workspace = true }
+hypr-analytics = { workspace = true }
+owhisper-providers = { workspace = true }

 axum = { workspace = true, features = ["ws"] }
-bytes = { workspace = true }
 futures-util = { workspace = true }
+reqwest = { workspace = true, features = ["json"] }
 tokio = { workspace = true, features = ["rt-multi-thread", "time", "sync", "macros"] }
 tokio-tungstenite = { workspace = true, features = ["native-tls-vendored"] }
 tower = { workspace = true }
 tracing = { workspace = true }
+uuid = { workspace = true, features = ["v4"] }
+
+bytes = { workspace = true }
+serde = { workspace = true, features = ["derive"] }
+serde_json = { workspace = true }
+thiserror = { workspace = true }
+url = { workspace = true }
diff --git a/crates/transcribe-proxy/src/analytics.rs b/crates/transcribe-proxy/src/analytics.rs
new file mode 100644
index 0000000000..2d9a2bab30
--- /dev/null
+++ b/crates/transcribe-proxy/src/analytics.rs
@@ -0,0 +1,31 @@
+use std::time::Duration;
+
+use hypr_analytics::{AnalyticsClient, AnalyticsPayload};
+
+#[derive(Debug, Clone)]
+pub struct SttEvent {
+    pub provider: String,
+    pub duration: Duration,
+}
+
+pub trait SttAnalyticsReporter: Send + Sync {
+    fn report_stt(
+        &self,
+        event: SttEvent,
+    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>>;
+}
+
+impl SttAnalyticsReporter for AnalyticsClient {
+    fn report_stt(
+        &self,
+        event: SttEvent,
+    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>> {
+        Box::pin(async move {
+            let payload = AnalyticsPayload::builder("$stt_request")
+                .with("$stt_provider", event.provider.clone())
+                .with("$stt_duration", event.duration.as_secs_f64())
+                .build();
+            let _ = self.event(uuid::Uuid::new_v4().to_string(), payload).await;
+        })
+    }
+}
diff --git a/crates/transcribe-proxy/src/config.rs b/crates/transcribe-proxy/src/config.rs
new file mode 100644
index 0000000000..f879525619
--- /dev/null
+++ b/crates/transcribe-proxy/src/config.rs
@@ -0,0 +1,47 @@
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::Duration;
+
+use owhisper_providers::Provider;
+
+use crate::analytics::SttAnalyticsReporter;
+
+const DEFAULT_CONNECT_TIMEOUT_MS: u64 = 5000;
+
+#[derive(Clone)]
+pub struct SttProxyConfig {
+    pub api_keys: HashMap<Provider, String>,
+    pub default_provider: Provider,
+    pub connect_timeout: Duration,
+    pub analytics: Option<Arc<dyn SttAnalyticsReporter>>,
+}
+
+impl SttProxyConfig {
+    pub fn new(api_keys: HashMap<Provider, String>) -> Self {
+        Self {
+            api_keys,
+            default_provider: Provider::Deepgram,
+            connect_timeout: Duration::from_millis(DEFAULT_CONNECT_TIMEOUT_MS),
+            analytics: None,
+        }
+    }
+
+    pub fn with_default_provider(mut self, provider: Provider) -> Self {
+        self.default_provider = provider;
+        self
+    }
+
+    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
+        self.connect_timeout = timeout;
+        self
+    }
+
+    pub fn with_analytics(mut self, analytics: Arc<dyn SttAnalyticsReporter>) -> Self {
+        self.analytics = Some(analytics);
+        self
+    }
+
+    pub fn api_key_for(&self, provider: Provider) -> Option<&str> {
+        self.api_keys.get(&provider).map(|s| s.as_str())
+    }
+}
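Review note: assembling the STT config, mirroring what apps/ai does with `env().api_keys()` — a sketch where the key is a placeholder value:

```rust
use std::collections::HashMap;
use std::time::Duration;

use owhisper_providers::Provider;

// Illustrative wiring of the builder shown above.
fn example_stt_config() -> SttProxyConfig {
    let mut keys = HashMap::new();
    keys.insert(Provider::Deepgram, "dg_placeholder_key".to_string());

    SttProxyConfig::new(keys)
        .with_default_provider(Provider::Deepgram)
        .with_connect_timeout(Duration::from_secs(3))
}
```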
diff --git a/crates/transcribe-proxy/src/lib.rs b/crates/transcribe-proxy/src/lib.rs
index 2cef241e0e..80499c853a 100644
--- a/crates/transcribe-proxy/src/lib.rs
+++ b/crates/transcribe-proxy/src/lib.rs
@@ -1,4 +1,108 @@
+mod analytics;
+mod config;
 mod error;
+mod router;
 mod service;
+
+pub use analytics::{SttAnalyticsReporter, SttEvent};
+pub use config::*;
 pub use error::*;
-pub use service::*;
+pub use router::router;
+pub use service::WebSocketProxy;
+
+#[cfg(test)]
+mod tests {
+    use std::collections::HashMap;
+    use std::sync::{Arc, Mutex};
+    use std::time::Duration;
+
+    use futures_util::{SinkExt, StreamExt};
+    use tokio_tungstenite::connect_async;
+    use tokio_tungstenite::tungstenite::Message;
+
+    use super::*;
+    use owhisper_providers::Provider;
+
+    #[derive(Default, Clone)]
+    struct MockAnalytics {
+        events: Arc<Mutex<Vec<SttEvent>>>,
+    }
+
+    impl SttAnalyticsReporter for MockAnalytics {
+        fn report_stt(
+            &self,
+            event: SttEvent,
+        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>> {
+            let events = self.events.clone();
+            Box::pin(async move {
+                events.lock().unwrap().push(event);
+            })
+        }
+    }
+
+    #[ignore]
+    #[tokio::test]
+    async fn e2e_deepgram_with_mock_analytics() {
+        let api_key = std::env::var("DEEPGRAM_API_KEY").expect("DEEPGRAM_API_KEY must be set");
+
+        let mock_analytics = MockAnalytics::default();
+        let events = mock_analytics.events.clone();
+
+        let mut api_keys = HashMap::new();
+        api_keys.insert(Provider::Deepgram, api_key);
+
+        let config = SttProxyConfig::new(api_keys)
+            .with_default_provider(Provider::Deepgram)
+            .with_analytics(Arc::new(mock_analytics));
+
+        let app = router(config);
+
+        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let addr = listener.local_addr().unwrap();
+
+        tokio::spawn(async move {
+            axum::serve(listener, app).await.unwrap();
+        });
+
+        tokio::time::sleep(Duration::from_millis(100)).await;
+
+        let url = format!(
+            "ws://{}/ws?encoding=linear16&sample_rate=16000&channels=1",
+            addr
+        );
+        let (mut ws_stream, _) = connect_async(&url).await.expect("failed to connect");
+
+        let audio_data = vec![0u8; 3200];
+        ws_stream
+            .send(Message::Binary(audio_data.into()))
+            .await
+            .expect("failed to send audio");
+
+        tokio::time::sleep(Duration::from_millis(500)).await;
+
+        let close_msg = serde_json::json!({"type": "CloseStream"});
+        ws_stream
+            .send(Message::Text(close_msg.to_string().into()))
+            .await
+            .expect("failed to send close");
+
+        while let Some(msg) = ws_stream.next().await {
+            match msg {
+                Ok(Message::Close(_)) => break,
+                Ok(_) => continue,
+                Err(_) => break,
+            }
+        }
+
+        let _ = ws_stream.close(None).await;
+
+        tokio::time::sleep(Duration::from_secs(1)).await;
+
+        let captured_events = events.lock().unwrap();
+        assert_eq!(captured_events.len(), 1);
+
+        let event = &captured_events[0];
+        assert_eq!(event.provider, "deepgram");
+        assert!(event.duration.as_secs_f64() > 0.0);
+    }
+}
diff --git a/apps/stt/src/handlers.rs b/crates/transcribe-proxy/src/router.rs
similarity index 62%
rename from apps/stt/src/handlers.rs
rename to crates/transcribe-proxy/src/router.rs
index 3e427b8d9d..ec9ffb1950 100644
--- a/apps/stt/src/handlers.rs
+++ b/crates/transcribe-proxy/src/router.rs
@@ -1,31 +1,58 @@
 use std::collections::HashMap;
+use std::sync::Arc;

 use axum::{
-    extract::{Query, WebSocketUpgrade},
+    Router,
+    extract::{Query, State, WebSocketUpgrade},
     http::StatusCode,
     response::{IntoResponse, Response},
+    routing::any,
 };
+use owhisper_providers::{Auth, Provider};
 use serde::{Deserialize, Serialize};

-use crate::auth::AuthUser;
-use crate::env::env;
-use hypr_transcribe_proxy::WebSocketProxy;
-use owhisper_providers::{Auth, Provider};
+use crate::analytics::{SttAnalyticsReporter, SttEvent};
+use crate::config::SttProxyConfig;
+use crate::service::WebSocketProxy;

 const IGNORED_PARAMS: &[&str] = &["provider", "keywords", "keyterm", "keyterms"];

-pub async fn ws_handler(
-    auth: AuthUser,
+#[derive(Clone)]
+struct AppState {
+    config: SttProxyConfig,
+    client: reqwest::Client,
+}
+
+pub fn router(config: SttProxyConfig) -> Router {
+    let state = AppState {
+        config,
+        client: reqwest::Client::new(),
+    };
+
+    Router::new()
+        .route("/ws", any(ws_handler))
+        .with_state(state)
+}
+
+async fn ws_handler(
+    State(state): State<AppState>,
     ws: WebSocketUpgrade,
     Query(params): Query<HashMap<String, String>>,
 ) -> Response {
-    tracing::info!(user_id = %auth.user_id, is_pro = %auth.is_pro(), "ws connection");
     let provider = params
         .get("provider")
         .and_then(|s| s.parse::<Provider>().ok())
-        .unwrap_or(Provider::Deepgram);
+        .unwrap_or(state.config.default_provider);

-    let upstream_url = match resolve_upstream_url(provider, &params).await {
+    let api_key = match state.config.api_key_for(provider) {
+        Some(key) => key.to_string(),
+        None => {
+            tracing::error!(provider = ?provider, "api key not configured");
+            return (StatusCode::INTERNAL_SERVER_ERROR, "provider not configured").into_response();
+        }
+    };
+
+    let upstream_url = match resolve_upstream_url(&state, provider, &api_key, &params).await {
         Ok(url) => url,
         Err(e) => {
             tracing::error!(error = %e, "failed to resolve upstream url");
@@ -33,16 +60,20 @@ pub async fn ws_handler(
         }
     };

-    let proxy = build_proxy(provider, &upstream_url);
+    let proxy = build_proxy(provider, &api_key, &upstream_url, &state.config);

     proxy.handle_upgrade(ws).await.into_response()
 }

 async fn resolve_upstream_url(
+    state: &AppState,
     provider: Provider,
+    api_key: &str,
     params: &HashMap<String, String>,
 ) -> Result<String, String> {
     match provider.auth() {
-        Auth::SessionInit { header_name } => init_session(provider, header_name, params).await,
+        Auth::SessionInit { header_name } => {
+            init_session(state, provider, header_name, api_key, params).await
+        }
         _ => {
             let mut url = url::Url::parse(&provider.default_ws_url()).unwrap();
             for (key, value) in params {
@@ -59,13 +90,12 @@ async fn resolve_upstream_url(
 }

 async fn init_session(
+    state: &AppState,
     provider: Provider,
     header_name: &'static str,
+    api_key: &str,
     params: &HashMap<String, String>,
 ) -> Result<String, String> {
-    let env = env();
-    let api_key = env.api_key_for(provider);
-
     let init_url = provider
         .default_api_url()
         .ok_or_else(|| format!("{:?} does not support session init", provider))?;
@@ -94,10 +124,10 @@ async fn init_session(
         },
     };

-    let client = reqwest::Client::new();
-    let resp = client
+    let resp = state
+        .client
         .post(init_url)
-        .header(header_name, &api_key)
+        .header(header_name, api_key)
         .header("Content-Type", "application/json")
         .json(&config)
         .send()
@@ -120,26 +150,46 @@ async fn init_session(
     Ok(init.url)
 }

-fn build_proxy(provider: Provider, upstream_url: &str) -> WebSocketProxy {
-    let env = env();
-    let api_key = env.api_key_for(provider);
-
-    let mut builder = WebSocketProxy::builder().upstream_url(upstream_url);
+fn build_proxy(
+    provider: Provider,
+    api_key: &str,
+    upstream_url: &str,
+    config: &SttProxyConfig,
+) -> WebSocketProxy {
+    let mut builder = WebSocketProxy::builder()
+        .upstream_url(upstream_url)
+        .connect_timeout(config.connect_timeout);

     match provider.auth() {
         Auth::Header { .. } => {
-            if let Some((name, value)) = provider.build_auth_header(&api_key) {
+            if let Some((name, value)) = provider.build_auth_header(api_key) {
                 builder = builder.header(name, value);
             }
         }
         Auth::FirstMessage { .. } => {
             let auth = provider.auth();
+            let api_key = api_key.to_string();
             builder = builder
                 .transform_first_message(move |msg| auth.transform_first_message(msg, &api_key));
         }
         Auth::SessionInit { .. } => {}
     }

+    if let Some(analytics) = config.analytics.clone() {
+        let provider_name = format!("{:?}", provider).to_lowercase();
+        builder = builder.on_close(move |duration| {
+            let analytics: Arc<dyn SttAnalyticsReporter> = analytics.clone();
+            let provider_name = provider_name.clone();
+            tokio::spawn(async move {
+                let event = SttEvent {
+                    provider: provider_name,
+                    duration,
+                };
+                analytics.report_stt(event).await;
+            });
+        });
+    }
+
     builder.build()
 }
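Review note: a hedged sketch of a caller hitting this route. When mounted under `/stt` as in apps/ai, the path becomes `/stt/ws`; the host here is illustrative, and an unrecognized `provider` value falls back to the configured default:

```rust
use futures_util::SinkExt;
use tokio_tungstenite::{connect_async, tungstenite::Message};

// Hypothetical client: open the proxy socket, push one audio frame, close.
async fn send_one_frame() -> Result<(), Box<dyn std::error::Error>> {
    let url = "ws://localhost:3000/stt/ws?provider=deepgram&encoding=linear16&sample_rate=16000";
    let (mut ws, _) = connect_async(url).await?;
    ws.send(Message::Binary(vec![0u8; 3200].into())).await?;
    ws.close(None).await?;
    Ok(())
}
```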
diff --git a/crates/transcribe-proxy/src/service.rs b/crates/transcribe-proxy/src/service.rs
index 64d0b099c1..a3dc15c8c1 100644
--- a/crates/transcribe-proxy/src/service.rs
+++ b/crates/transcribe-proxy/src/service.rs
@@ -3,7 +3,7 @@ use std::future::Future;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
-use std::time::Duration;
+use std::time::{Duration, Instant};

 use axum::body::Body;
 use axum::extract::ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade};
@@ -26,6 +26,8 @@ const DEFAULT_CLOSE_CODE: u16 = 1011;
 const UPSTREAM_CONNECT_TIMEOUT_MS: u64 = 5000;
 const MAX_PENDING_QUEUE_BYTES: usize = 5 * 1024 * 1024; // 5 MiB

+type OnCloseCallback = Arc<dyn Fn(Duration) + Send + Sync>;
+
 #[derive(Debug, Clone)]
 struct QueuedPayload {
     data: Vec<u8>,
@@ -33,6 +35,13 @@ struct QueuedPayload {
     size: usize,
 }

+#[derive(Clone)]
+struct PendingState {
+    control_messages: Arc<Mutex<Vec<QueuedPayload>>>,
+    data_messages: Arc<Mutex<Vec<QueuedPayload>>>,
+    bytes: Arc<Mutex<usize>>,
+}
+
 type ControlMessageMatcher = Arc<dyn Fn(&str) -> bool + Send + Sync>;
 type FirstMessageTransformer = Arc<dyn Fn(String) -> String + Send + Sync>;
 type UpstreamSender = SplitSink<
@@ -49,6 +58,7 @@ pub struct WebSocketProxyBuilder {
     control_message_matcher: Option<ControlMessageMatcher>,
     transform_first_message: Option<FirstMessageTransformer>,
     connect_timeout: Duration,
+    on_close: Option<OnCloseCallback>,
 }

 impl Default for WebSocketProxyBuilder {
@@ -59,6 +69,7 @@
             control_message_matcher: None,
             transform_first_message: None,
             connect_timeout: Duration::from_millis(UPSTREAM_CONNECT_TIMEOUT_MS),
+            on_close: None,
         }
     }
 }
@@ -88,6 +99,7 @@ impl WebSocketProxyBuilder {
             control_message_matcher: self.control_message_matcher,
             transform_first_message: self.transform_first_message,
             connect_timeout: self.connect_timeout,
+            on_close: self.on_close,
         }
     }

@@ -112,6 +124,14 @@ impl WebSocketProxyBuilder {
         self
     }

+    pub fn on_close<F>(mut self, callback: F) -> Self
+    where
+        F: Fn(Duration) + Send + Sync + 'static,
+    {
+        self.on_close = Some(Arc::new(callback));
+        self
+    }
+
     pub fn build(self) -> WebSocketProxy {
         let url = self.upstream_url.expect("upstream_url is required");
         let mut request = ClientRequestBuilder::new(url.parse().expect("invalid upstream URL"));
@@ -125,6 +145,7 @@ impl WebSocketProxyBuilder {
             control_message_matcher: self.control_message_matcher,
             transform_first_message: self.transform_first_message,
             connect_timeout: self.connect_timeout,
+            on_close: self.on_close,
         }
     }
 }
@@ -134,6 +155,7 @@ pub struct WebSocketProxyBuilderWithRequest {
     control_message_matcher: Option<ControlMessageMatcher>,
     transform_first_message: Option<FirstMessageTransformer>,
     connect_timeout: Duration,
+    on_close: Option<OnCloseCallback>,
 }

 impl WebSocketProxyBuilderWithRequest {
@@ -158,12 +180,21 @@ impl WebSocketProxyBuilderWithRequest {
         self
     }

+    pub fn on_close<F>(mut self, callback: F) -> Self
+    where
+        F: Fn(Duration) + Send + Sync + 'static,
+    {
+        self.on_close = Some(Arc::new(callback));
+        self
+    }
+
     pub fn build(self) -> WebSocketProxy {
         WebSocketProxy {
             upstream_request: self.upstream_request,
             control_message_matcher: self.control_message_matcher,
             transform_first_message: self.transform_first_message,
             connect_timeout: self.connect_timeout,
+            on_close: self.on_close,
         }
     }
 }
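Review note: the builder surface with the new `on_close` hook, as a usage sketch — the upstream URL is illustrative:

```rust
use std::time::Duration;

// Illustrative assembly of a proxy; `on_close` receives the total session
// duration once both pump tasks have finished.
fn example_proxy() -> WebSocketProxy {
    WebSocketProxy::builder()
        .upstream_url("wss://api.deepgram.com/v1/listen")
        .connect_timeout(Duration::from_secs(5))
        .on_close(|elapsed| {
            tracing::info!(secs = elapsed.as_secs_f64(), "proxy session closed");
        })
        .build()
}
```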
@@ -174,6 +205,7 @@ pub struct WebSocketProxy {
     control_message_matcher: Option<ControlMessageMatcher>,
     transform_first_message: Option<FirstMessageTransformer>,
     connect_timeout: Duration,
+    on_close: Option<OnCloseCallback>,
 }

 impl WebSocketProxy {
@@ -187,6 +219,7 @@ impl WebSocketProxy {
             self.control_message_matcher.clone(),
             self.transform_first_message.clone(),
             self.connect_timeout,
+            self.on_close.clone(),
         );
         connection.run(client_socket).await;
     }
@@ -224,6 +257,7 @@ impl WebSocketProxy {
             upstream_stream: Some(upstream_stream),
             control_message_matcher: self.control_message_matcher.clone(),
             transform_first_message: self.transform_first_message.clone(),
+            on_close: self.on_close.clone(),
         })
     }
 }
@@ -266,6 +300,7 @@ pub struct PreconnectedProxy {
     upstream_stream: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
     control_message_matcher: Option<ControlMessageMatcher>,
     transform_first_message: Option<FirstMessageTransformer>,
+    on_close: Option<OnCloseCallback>,
 }

 impl PreconnectedProxy {
@@ -283,6 +318,7 @@ impl PreconnectedProxy {
             upstream_stream,
             self.control_message_matcher,
             self.transform_first_message,
+            self.on_close,
         )
         .await;
     }
@@ -300,6 +336,7 @@ struct WebSocketProxyConnection {
     control_message_matcher: Option<ControlMessageMatcher>,
     transform_first_message: Option<FirstMessageTransformer>,
     connect_timeout: Duration,
+    on_close: Option<OnCloseCallback>,
 }

 impl WebSocketProxyConnection {
@@ -308,12 +345,14 @@ impl WebSocketProxyConnection {
         control_message_matcher: Option<ControlMessageMatcher>,
         transform_first_message: Option<FirstMessageTransformer>,
         connect_timeout: Duration,
+        on_close: Option<OnCloseCallback>,
     ) -> Self {
         Self {
             upstream_request,
             control_message_matcher,
             transform_first_message,
             connect_timeout,
+            on_close,
         }
     }
@@ -330,6 +369,8 @@ impl WebSocketProxyConnection {
     }

     async fn run(self, client_socket: WebSocket) {
+        let start_time = Instant::now();
+
         let req = match self.upstream_request.into_client_request() {
             Ok(r) => r,
             Err(e) => {
@@ -357,11 +398,11 @@ impl WebSocketProxyConnection {
         let (upstream_sender, upstream_receiver) = upstream_stream.split();
         let (client_sender, client_receiver) = client_socket.split();

-        let pending_control_messages: Arc<Mutex<Vec<QueuedPayload>>> =
-            Arc::new(Mutex::new(Vec::new()));
-        let pending_data_messages: Arc<Mutex<Vec<QueuedPayload>>> =
-            Arc::new(Mutex::new(Vec::new()));
-        let pending_bytes: Arc<Mutex<usize>> = Arc::new(Mutex::new(0));
+        let pending_state = PendingState {
+            control_messages: Arc::new(Mutex::new(Vec::new())),
+            data_messages: Arc::new(Mutex::new(Vec::new())),
+            bytes: Arc::new(Mutex::new(0)),
+        };

         let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<(u16, String)>(1);
         let shutdown_rx2 = shutdown_tx.subscribe();
@@ -376,9 +417,7 @@ impl WebSocketProxyConnection {
             shutdown_rx,
             control_matcher,
             first_msg_transformer,
-            pending_control_messages,
-            pending_data_messages,
-            pending_bytes,
+            pending_state,
         );

         let upstream_to_client = Self::run_upstream_to_client(
@@ -389,6 +428,11 @@ impl WebSocketProxyConnection {
         );

         let _ = tokio::join!(client_to_upstream, upstream_to_client);
+
+        if let Some(on_close) = self.on_close {
+            on_close(start_time.elapsed());
+        }
+
         tracing::info!("websocket_proxy_connection_closed");
     }

@@ -397,15 +441,18 @@ impl WebSocketProxyConnection {
         upstream_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
         control_message_matcher: Option<ControlMessageMatcher>,
         transform_first_message: Option<FirstMessageTransformer>,
+        on_close: Option<OnCloseCallback>,
     ) {
+        let start_time = Instant::now();
+
         let (upstream_sender, upstream_receiver) = upstream_stream.split();
         let (client_sender, client_receiver) = client_socket.split();

-        let pending_control_messages: Arc<Mutex<Vec<QueuedPayload>>> =
-            Arc::new(Mutex::new(Vec::new()));
-        let pending_data_messages: Arc<Mutex<Vec<QueuedPayload>>> =
-            Arc::new(Mutex::new(Vec::new()));
-        let pending_bytes: Arc<Mutex<usize>> = Arc::new(Mutex::new(0));
+        let pending_state = PendingState {
+            control_messages: Arc::new(Mutex::new(Vec::new())),
+            data_messages: Arc::new(Mutex::new(Vec::new())),
+            bytes: Arc::new(Mutex::new(0)),
+        };

         let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<(u16, String)>(1);
         let shutdown_rx2 = shutdown_tx.subscribe();
@@ -417,9 +464,7 @@ impl WebSocketProxyConnection {
             shutdown_rx,
             control_message_matcher,
             transform_first_message,
-            pending_control_messages,
-            pending_data_messages,
-            pending_bytes,
+            pending_state,
         );

         let upstream_to_client = Self::run_upstream_to_client(
@@ -430,6 +475,11 @@ impl WebSocketProxyConnection {
         );

         let _ = tokio::join!(client_to_upstream, upstream_to_client);
+
+        if let Some(on_close) = on_close {
+            on_close(start_time.elapsed());
+        }
+
         tracing::info!("websocket_proxy_connection_closed");
     }

@@ -440,10 +490,11 @@ impl WebSocketProxyConnection {
         mut shutdown_rx: tokio::sync::broadcast::Receiver<(u16, String)>,
         control_matcher: Option<ControlMessageMatcher>,
         first_msg_transformer: Option<FirstMessageTransformer>,
-        pending_control_messages: Arc<Mutex<Vec<QueuedPayload>>>,
-        pending_data_messages: Arc<Mutex<Vec<QueuedPayload>>>,
-        pending_bytes: Arc<Mutex<usize>>,
+        pending_state: PendingState,
     ) {
+        let pending_control_messages = pending_state.control_messages;
+        let pending_data_messages = pending_state.data_messages;
+        let pending_bytes = pending_state.bytes;
         let mut has_transformed_first = first_msg_transformer.is_none();

         loop {
diff --git a/doxxer.ai.toml b/doxxer.ai.toml
new file mode 100644
index 0000000000..ad1f78e3b8
--- /dev/null
+++ b/doxxer.ai.toml
@@ -0,0 +1,13 @@
+filter.tag = "^ai_v"
+
+[output]
+format = "plain"
+
+[next.patch]
+increment = 1
+
+[next.minor]
+increment = 1
+
+[next.major]
+increment = 1