diff --git a/Cargo.lock b/Cargo.lock index d872a5231e..c681d2f79f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1426,7 +1426,7 @@ dependencies = [ "hmac", "http 0.2.12", "http 1.4.0", - "p256", + "p256 0.11.1", "percent-encoding", "ring", "sha2", @@ -1847,6 +1847,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce" +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + [[package]] name = "base64" version = "0.13.1" @@ -3821,8 +3827,10 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" dependencies = [ + "generic-array", "rand_core 0.6.4", "subtle", + "zeroize", ] [[package]] @@ -3920,6 +3928,33 @@ dependencies = [ "libloading 0.8.9", ] +[[package]] +name = "curve25519-dalek" +version = "4.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" +dependencies = [ + "cfg-if", + "cpufeatures", + "curve25519-dalek-derive", + "digest", + "fiat-crypto", + "rustc_version 0.4.1", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "dagc" version = "0.1.1" @@ -4409,6 +4444,7 @@ version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" dependencies = [ + "const-oid", "pem-rfc7468", "zeroize", ] @@ -4611,6 +4647,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", + "const-oid", "crypto-common", "subtle", ] @@ -4777,6 +4814,12 @@ dependencies = [ "litrs", ] +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + [[package]] name = "downcast-rs" version = "1.2.1" @@ -4945,9 +4988,47 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "413301934810f597c1d19ca71c8710e99a3f1ba28a0d2ebc01551a2daeea3c5c" dependencies = [ "der 0.6.1", - "elliptic-curve", - "rfc6979", - "signature", + "elliptic-curve 0.12.3", + "rfc6979 0.3.1", + "signature 1.6.4", +] + +[[package]] +name = "ecdsa" +version = "0.16.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" +dependencies = [ + "der 0.7.10", + "digest", + "elliptic-curve 0.13.8", + "rfc6979 0.4.0", + "signature 2.2.0", + "spki 0.7.3", +] + +[[package]] +name = "ed25519" +version = "2.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" +dependencies = [ + "pkcs8 0.10.2", + "signature 2.2.0", +] + +[[package]] +name = "ed25519-dalek" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e796c081cee67dc755e1a36a0a172b897fab85fc3f6bc48307991f64e4eca9" +dependencies = [ + "curve25519-dalek", + "ed25519", + "serde", + "sha2", + "subtle", + "zeroize", ] [[package]] @@ -4962,16 +5043,37 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e7bb888ab5300a19b8e5bceef25ac745ad065f3c9f7efc6de1b91958110891d3" dependencies = [ - "base16ct", + "base16ct 0.1.1", "crypto-bigint 0.4.9", "der 0.6.1", "digest", - "ff", + "ff 0.12.1", + "generic-array", + "group 0.12.1", + "pkcs8 0.9.0", + "rand_core 0.6.4", + "sec1 0.3.0", + "subtle", + "zeroize", +] + +[[package]] +name = "elliptic-curve" +version = "0.13.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" +dependencies = [ + "base16ct 0.2.0", + "crypto-bigint 0.5.5", + "digest", + "ff 0.13.1", "generic-array", - "group", - "pkcs8", + "group 0.13.0", + "hkdf", + "pem-rfc7468", + "pkcs8 0.10.2", "rand_core 0.6.4", - "sec1", + "sec1 0.7.3", "subtle", "zeroize", ] @@ -5371,6 +5473,22 @@ dependencies = [ "subtle", ] +[[package]] +name = "ff" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" +dependencies = [ + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + [[package]] name = "field-offset" version = "0.3.6" @@ -6207,6 +6325,7 @@ checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", + "zeroize", ] [[package]] @@ -7835,7 +7954,18 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5dfbfb3a6cfbd390d5c9564ab283a0349b9b9fcd46a706c1eb10e0db70bfbac7" dependencies = [ - "ff", + "ff 0.12.1", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff 0.13.1", "rand_core 0.6.4", "subtle", ] @@ -9195,6 +9325,29 @@ dependencies = [ "serde_json", ] +[[package]] +name = "jsonwebtoken" +version = "10.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c76e1c7d7df3e34443b3621b459b066a7b79644f059fc8b2db7070c825fd417e" +dependencies = [ + "base64 0.22.1", + "ed25519-dalek", + "getrandom 0.2.16", + "hmac", + "js-sys", + "p256 0.13.2", + "p384", + "pem", + "rand 0.8.5", + "rsa", + "serde", + "serde_json", + "sha2", + "signature 2.2.0", + "simple_asn1", +] + [[package]] name = "keyboard-types" version = "0.7.0" @@ -11821,6 +11974,7 @@ dependencies = [ "futures-util", "language", "owhisper-interface", + "owhisper-providers", "reqwest", "reqwest-middleware", "reqwest-tracing", @@ -11876,6 +12030,15 @@ dependencies = [ "whisper-local-model", ] +[[package]] +name = "owhisper-providers" +version = "0.0.1" +dependencies = [ + "serde_json", + "strum 0.26.3", + "url", +] + [[package]] name = "owhisper-server" version = "0.0.3" @@ -11934,8 +12097,32 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51f44edd08f51e2ade572f141051021c5af22677e42b7dd28a88155151c33594" dependencies = [ - "ecdsa", - "elliptic-curve", + "ecdsa 0.14.8", + "elliptic-curve 0.12.3", + "sha2", +] + +[[package]] +name = "p256" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9863ad85fa8f4460f9c48cb909d38a0d689dba1f6f6988a5e3e0d31071bcd4b" +dependencies = [ + "ecdsa 0.16.9", + "elliptic-curve 0.13.8", + "primeorder", + "sha2", +] + +[[package]] +name = "p384" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe42f1670a52a47d448f14b6a5c61dd78fce51856e68edaa38f7ae3a46b8d6b6" +dependencies = [ + "ecdsa 0.16.9", + "elliptic-curve 0.13.8", + "primeorder", "sha2", ] @@ -12114,6 +12301,16 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" +[[package]] +name = "pem" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" +dependencies = [ + "base64 0.22.1", + "serde_core", +] + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -12447,6 +12644,17 @@ dependencies = [ "image 0.23.14", ] +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der 0.7.10", + "pkcs8 0.10.2", + "spki 0.7.3", +] + [[package]] name = "pkcs8" version = "0.9.0" @@ -12454,7 +12662,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9eca2c590a5f85da82668fa685c09ce2888b9430e83299debf1f34b65fd4a4ba" dependencies = [ "der 0.6.1", - "spki", + "spki 0.6.0", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der 0.7.10", + "spki 0.7.3", ] [[package]] @@ -12703,6 +12921,15 @@ dependencies = [ "num-integer", ] +[[package]] +name = "primeorder" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "353e1ca18966c16d9deb1c69278edbc5f194139612772bd9537af60ac231e1e6" +dependencies = [ + "elliptic-curve 0.13.8", +] + [[package]] name = "proc-macro-crate" version = "1.3.1" @@ -13732,6 +13959,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rfc6979" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" +dependencies = [ + "hmac", + "subtle", +] + [[package]] name = "rfd" version = "0.15.4" @@ -13854,6 +14091,26 @@ dependencies = [ "cc", ] +[[package]] +name = "rsa" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40a0376c50d0358279d9d643e4bf7b7be212f1f4ff1da9070a7b54d22ef75c88" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8 0.10.2", + "rand_core 0.6.4", + "signature 2.2.0", + "spki 0.7.3", + "subtle", + "zeroize", +] + [[package]] name = "rubato" version = "0.16.2" @@ -14371,10 +14628,24 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3be24c1842290c45df0a7bf069e0c268a747ad05a192f2fd7dcfdbc1cba40928" dependencies = [ - "base16ct", + "base16ct 0.1.1", "der 0.6.1", "generic-array", - "pkcs8", + "pkcs8 0.9.0", + "subtle", + "zeroize", +] + +[[package]] +name = "sec1" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" +dependencies = [ + "base16ct 0.2.0", + "der 0.7.10", + "generic-array", + "pkcs8 0.10.2", "subtle", "zeroize", ] @@ -15123,6 +15394,16 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core 0.6.4", +] + [[package]] name = "silero" version = "0.1.0" @@ -15172,6 +15453,18 @@ version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" +[[package]] +name = "simple_asn1" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror 2.0.17", + "time", +] + [[package]] name = "simplecss" version = "0.2.2" @@ -15486,6 +15779,16 @@ dependencies = [ "der 0.6.1", ] +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der 0.7.10", +] + [[package]] name = "spm_precompiled" version = "0.1.4" @@ -15733,12 +16036,19 @@ name = "stt-server" version = "0.1.0" dependencies = [ "axum 0.8.7", + "dotenvy", + "jsonwebtoken", + "owhisper-providers", + "reqwest", "sentry", + "serde", + "serde_json", "tokio", "tower-http 0.6.8", "tracing", "tracing-subscriber", "transcribe-proxy", + "url", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 9d3d3b49d3..3843fbe3cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -89,6 +89,7 @@ owhisper-client = { path = "owhisper/owhisper-client", package = "owhisper-clien owhisper-config = { path = "owhisper/owhisper-config", package = "owhisper-config" } owhisper-interface = { path = "owhisper/owhisper-interface", package = "owhisper-interface" } owhisper-model = { path = "owhisper/owhisper-model", package = "owhisper-model" } +owhisper-providers = { path = "owhisper/owhisper-providers", package = "owhisper-providers" } tauri = "2.9" tauri-build = "2.5" @@ -165,7 +166,7 @@ clap = "4" codes-iso-639 = "0.1.5" derive_more = "2" dirs = "6.0.0" -dotenv = "0.15.0" +dotenvy = "0.15.7" include_url_macro = "0.1.0" indoc = "2" itertools = "0.14.0" @@ -205,6 +206,7 @@ async-openai = { git = "https://github.com/fastrepl/async-openai", rev = "6404d3 async-stripe = { version = "0.39.1", default-features = false } gbnf-validator = { git = "https://github.com/fastrepl/gbnf-validator", rev = "3dec055" } +jsonwebtoken = { version = "10", features = ["rust_crypto"] } sentry = "0.42" vergen-gix = "1" diff --git a/apps/stt/Cargo.toml b/apps/stt/Cargo.toml index 0eec21562f..1e1c63ba8b 100644 --- a/apps/stt/Cargo.toml +++ b/apps/stt/Cargo.toml @@ -9,6 +9,7 @@ path = "src/main.rs" [dependencies] hypr-transcribe-proxy = { workspace = true } +owhisper-providers = { workspace = true } axum = { workspace = true, features = ["ws"] } tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] } @@ -16,4 +17,11 @@ tower-http = { workspace = true, features = ["trace"] } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } +reqwest = { workspace = true, features = ["json"] } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +url = { workspace = true } + +dotenvy = { workspace = true } +jsonwebtoken = { workspace = true } sentry = { workspace = true } diff --git a/apps/stt/src/auth.rs b/apps/stt/src/auth.rs new file mode 100644 index 0000000000..ecd6e20952 --- /dev/null +++ b/apps/stt/src/auth.rs @@ -0,0 +1,136 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use axum::{ + extract::FromRequestParts, + http::{StatusCode, request::Parts}, +}; +use jsonwebtoken::{DecodingKey, Validation, decode, decode_header, jwk::JwkSet}; +use serde::Deserialize; +use tokio::sync::RwLock; + +use crate::env::env; + +const ENTITLEMENT_PRO: &str = "hyprnote_pro"; +const JWKS_CACHE_TTL: Duration = Duration::from_secs(5 * 60); + +#[derive(Debug, Clone)] +pub struct AuthUser { + pub user_id: String, + pub entitlements: Vec, +} + +impl AuthUser { + pub fn is_pro(&self) -> bool { + self.entitlements.iter().any(|e| e == ENTITLEMENT_PRO) + } +} + +#[derive(Debug, Deserialize)] +struct Claims { + sub: String, + #[serde(default)] + entitlements: Vec, +} + +struct CachedJwks { + jwks: JwkSet, + fetched_at: Instant, +} + +static JWKS_CACHE: std::sync::OnceLock>>> = + std::sync::OnceLock::new(); + +fn jwks_cache() -> &'static Arc>> { + JWKS_CACHE.get_or_init(|| Arc::new(RwLock::new(None))) +} + +async fn get_jwks() -> Result { + let cache = jwks_cache(); + + { + let guard = cache.read().await; + if let Some(cached) = guard.as_ref() { + if cached.fetched_at.elapsed() < JWKS_CACHE_TTL { + return Ok(cached.jwks.clone()); + } + } + } + + let env = env(); + let jwks_url = format!("{}/auth/v1/.well-known/jwks.json", env.supabase_url); + + let jwks: JwkSet = reqwest::get(&jwks_url) + .await + .map_err(|_| "failed to fetch jwks")? + .json() + .await + .map_err(|_| "failed to parse jwks")?; + + { + let mut guard = cache.write().await; + *guard = Some(CachedJwks { + jwks: jwks.clone(), + fetched_at: Instant::now(), + }); + } + + Ok(jwks) +} + +impl FromRequestParts for AuthUser +where + S: Send + Sync, +{ + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let auth_header = parts + .headers + .get("authorization") + .and_then(|v| v.to_str().ok()) + .ok_or((StatusCode::UNAUTHORIZED, "missing authorization header"))?; + + let token = auth_header + .strip_prefix("Bearer ") + .or_else(|| auth_header.strip_prefix("bearer ")) + .ok_or((StatusCode::UNAUTHORIZED, "invalid authorization header"))?; + + let header = + decode_header(token).map_err(|_| (StatusCode::UNAUTHORIZED, "invalid token header"))?; + + let kid = header + .kid + .as_ref() + .ok_or((StatusCode::UNAUTHORIZED, "missing kid in token"))?; + + let jwks = get_jwks() + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; + + let jwk = jwks + .find(kid) + .ok_or((StatusCode::UNAUTHORIZED, "unknown signing key"))?; + + let key = DecodingKey::from_jwk(jwk) + .map_err(|_| (StatusCode::UNAUTHORIZED, "invalid signing key"))?; + + let alg = jwk + .common + .key_algorithm + .and_then(|a| a.to_string().parse().ok()) + .ok_or((StatusCode::UNAUTHORIZED, "unsupported algorithm"))?; + + let mut validation = Validation::new(alg); + validation.set_audience(&["authenticated"]); + validation.validate_exp = true; + + let token_data = decode::(token, &key, &validation) + .map_err(|_| (StatusCode::UNAUTHORIZED, "invalid token"))?; + + Ok(AuthUser { + user_id: token_data.claims.sub, + entitlements: token_data.claims.entitlements, + }) + } +} diff --git a/apps/stt/src/env.rs b/apps/stt/src/env.rs new file mode 100644 index 0000000000..4c376c0d3d --- /dev/null +++ b/apps/stt/src/env.rs @@ -0,0 +1,66 @@ +use std::collections::HashMap; +use std::sync::OnceLock; + +use owhisper_providers::Provider; + +pub struct Env { + pub port: u16, + pub sentry_dsn: Option, + pub supabase_url: String, + api_keys: HashMap, +} + +static ENV: OnceLock = OnceLock::new(); + +pub fn env() -> &'static Env { + ENV.get_or_init(|| { + let _ = dotenvy::dotenv(); + Env::from_env() + }) +} + +impl Env { + fn from_env() -> Self { + let providers = [ + Provider::Deepgram, + Provider::AssemblyAI, + Provider::Soniox, + Provider::Fireworks, + Provider::OpenAI, + Provider::Gladia, + ]; + let api_keys = providers + .into_iter() + .map(|p| (p, required(p.env_key_name()))) + .collect(); + + Self { + port: parse_or("PORT", 3000), + sentry_dsn: optional("SENTRY_DSN"), + supabase_url: required("SUPABASE_URL"), + api_keys, + } + } + + pub fn api_key_for(&self, provider: Provider) -> String { + self.api_keys + .get(&provider) + .cloned() + .unwrap_or_else(|| panic!("{} is not configured", provider.env_key_name())) + } +} + +fn required(key: &str) -> String { + std::env::var(key).unwrap_or_else(|_| panic!("{key} is required")) +} + +fn optional(key: &str) -> Option { + std::env::var(key).ok().filter(|s| !s.is_empty()) +} + +fn parse_or(key: &str, default: T) -> T { + std::env::var(key) + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(default) +} diff --git a/apps/stt/src/handlers.rs b/apps/stt/src/handlers.rs new file mode 100644 index 0000000000..3e427b8d9d --- /dev/null +++ b/apps/stt/src/handlers.rs @@ -0,0 +1,171 @@ +use std::collections::HashMap; + +use axum::{ + extract::{Query, WebSocketUpgrade}, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use serde::{Deserialize, Serialize}; + +use crate::auth::AuthUser; +use crate::env::env; +use hypr_transcribe_proxy::WebSocketProxy; +use owhisper_providers::{Auth, Provider}; + +const IGNORED_PARAMS: &[&str] = &["provider", "keywords", "keyterm", "keyterms"]; + +pub async fn ws_handler( + auth: AuthUser, + ws: WebSocketUpgrade, + Query(params): Query>, +) -> Response { + tracing::info!(user_id = %auth.user_id, is_pro = %auth.is_pro(), "ws connection"); + let provider = params + .get("provider") + .and_then(|s| s.parse::().ok()) + .unwrap_or(Provider::Deepgram); + + let upstream_url = match resolve_upstream_url(provider, ¶ms).await { + Ok(url) => url, + Err(e) => { + tracing::error!(error = %e, "failed to resolve upstream url"); + return (StatusCode::BAD_GATEWAY, e).into_response(); + } + }; + + let proxy = build_proxy(provider, &upstream_url); + proxy.handle_upgrade(ws).await.into_response() +} + +async fn resolve_upstream_url( + provider: Provider, + params: &HashMap, +) -> Result { + match provider.auth() { + Auth::SessionInit { header_name } => init_session(provider, header_name, params).await, + _ => { + let mut url = url::Url::parse(&provider.default_ws_url()).unwrap(); + for (key, value) in params { + if !IGNORED_PARAMS.contains(&key.as_str()) { + url.query_pairs_mut().append_pair(key, value); + } + } + for (key, value) in provider.default_query_params() { + url.query_pairs_mut().append_pair(key, value); + } + Ok(url.to_string()) + } + } +} + +async fn init_session( + provider: Provider, + header_name: &'static str, + params: &HashMap, +) -> Result { + let env = env(); + let api_key = env.api_key_for(provider); + + let init_url = provider + .default_api_url() + .ok_or_else(|| format!("{:?} does not support session init", provider))?; + + let sample_rate: u32 = params + .get("sample_rate") + .and_then(|s| s.parse().ok()) + .unwrap_or(16000); + + let channels: u8 = params + .get("channels") + .and_then(|s| s.parse().ok()) + .unwrap_or(1); + + let config = GladiaConfig { + encoding: "wav/pcm", + sample_rate, + bit_depth: 16, + channels, + messages_config: MessagesConfig { + receive_partial_transcripts: true, + receive_final_transcripts: true, + }, + realtime_processing: RealtimeProcessing { + words_accurate_timestamps: true, + }, + }; + + let client = reqwest::Client::new(); + let resp = client + .post(init_url) + .header(header_name, &api_key) + .header("Content-Type", "application/json") + .json(&config) + .send() + .await + .map_err(|e| format!("session init request failed: {}", e))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(format!("session init failed: {} - {}", status, body)); + } + + let init: InitResponse = resp + .json() + .await + .map_err(|e| format!("session init parse failed: {}", e))?; + + tracing::debug!(session_id = %init.id, url = %init.url, "session_initialized"); + + Ok(init.url) +} + +fn build_proxy(provider: Provider, upstream_url: &str) -> WebSocketProxy { + let env = env(); + let api_key = env.api_key_for(provider); + + let mut builder = WebSocketProxy::builder().upstream_url(upstream_url); + + match provider.auth() { + Auth::Header { .. } => { + if let Some((name, value)) = provider.build_auth_header(&api_key) { + builder = builder.header(name, value); + } + } + Auth::FirstMessage { .. } => { + let auth = provider.auth(); + builder = builder + .transform_first_message(move |msg| auth.transform_first_message(msg, &api_key)); + } + Auth::SessionInit { .. } => {} + } + + builder.build() +} + +#[derive(Serialize)] +struct GladiaConfig<'a> { + encoding: &'a str, + sample_rate: u32, + bit_depth: u8, + channels: u8, + messages_config: MessagesConfig, + realtime_processing: RealtimeProcessing, +} + +#[derive(Serialize)] +struct MessagesConfig { + receive_partial_transcripts: bool, + receive_final_transcripts: bool, +} + +#[derive(Serialize)] +struct RealtimeProcessing { + words_accurate_timestamps: bool, +} + +#[derive(Deserialize)] +struct InitResponse { + id: String, + url: String, +} diff --git a/apps/stt/src/main.rs b/apps/stt/src/main.rs index 2e6efe35a5..5e5dc1353e 100644 --- a/apps/stt/src/main.rs +++ b/apps/stt/src/main.rs @@ -1,23 +1,21 @@ +mod auth; +mod env; +mod handlers; + use std::net::SocketAddr; -use axum::{Router, extract::WebSocketUpgrade, response::IntoResponse, routing::any}; -use hypr_transcribe_proxy::WebSocketProxy; +use axum::{Router, routing::any}; use tower_http::trace::TraceLayer; +use env::env; +use handlers::ws_handler; + fn app() -> Router { Router::new() - .route("/ws", any(ws_handler)) + .route("/listen", any(ws_handler)) .layer(TraceLayer::new_for_http()) } -async fn ws_handler(ws: WebSocketUpgrade) -> impl IntoResponse { - let proxy = WebSocketProxy::builder() - .upstream_url("wss://example.com") - .build(); - - proxy.handle_upgrade(ws).await -} - #[tokio::main] async fn main() { tracing_subscriber::fmt() @@ -27,14 +25,14 @@ async fn main() { ) .init(); + let env = env(); + let _guard = sentry::init(sentry::ClientOptions { - dsn: std::env::var("SENTRY_DSN") - .ok() - .and_then(|s| s.parse().ok()), + dsn: env.sentry_dsn.as_ref().and_then(|s| s.parse().ok()), ..Default::default() }); - let addr = SocketAddr::from(([0, 0, 0, 0], 3000)); + let addr = SocketAddr::from(([0, 0, 0, 0], env.port)); tracing::info!("listening on {}", addr); let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); diff --git a/crates/transcribe-proxy/src/service.rs b/crates/transcribe-proxy/src/service.rs index 3c3a19e857..64d0b099c1 100644 --- a/crates/transcribe-proxy/src/service.rs +++ b/crates/transcribe-proxy/src/service.rs @@ -34,6 +34,7 @@ struct QueuedPayload { } type ControlMessageMatcher = Arc bool + Send + Sync>; +type FirstMessageTransformer = Arc String + Send + Sync>; type UpstreamSender = SplitSink< WebSocketStream>, tokio_tungstenite::tungstenite::Message, @@ -46,6 +47,7 @@ pub struct WebSocketProxyBuilder { upstream_url: Option, headers: HashMap, control_message_matcher: Option, + transform_first_message: Option, connect_timeout: Duration, } @@ -55,6 +57,7 @@ impl Default for WebSocketProxyBuilder { upstream_url: None, headers: HashMap::new(), control_message_matcher: None, + transform_first_message: None, connect_timeout: Duration::from_millis(UPSTREAM_CONNECT_TIMEOUT_MS), } } @@ -83,6 +86,7 @@ impl WebSocketProxyBuilder { WebSocketProxyBuilderWithRequest { upstream_request: request, control_message_matcher: self.control_message_matcher, + transform_first_message: self.transform_first_message, connect_timeout: self.connect_timeout, } } @@ -95,6 +99,14 @@ impl WebSocketProxyBuilder { self } + pub fn transform_first_message(mut self, transformer: F) -> Self + where + F: Fn(String) -> String + Send + Sync + 'static, + { + self.transform_first_message = Some(Arc::new(transformer)); + self + } + pub fn connect_timeout(mut self, timeout: Duration) -> Self { self.connect_timeout = timeout; self @@ -111,6 +123,7 @@ impl WebSocketProxyBuilder { WebSocketProxy { upstream_request: request, control_message_matcher: self.control_message_matcher, + transform_first_message: self.transform_first_message, connect_timeout: self.connect_timeout, } } @@ -119,6 +132,7 @@ impl WebSocketProxyBuilder { pub struct WebSocketProxyBuilderWithRequest { upstream_request: ClientRequestBuilder, control_message_matcher: Option, + transform_first_message: Option, connect_timeout: Duration, } @@ -131,6 +145,14 @@ impl WebSocketProxyBuilderWithRequest { self } + pub fn transform_first_message(mut self, transformer: F) -> Self + where + F: Fn(String) -> String + Send + Sync + 'static, + { + self.transform_first_message = Some(Arc::new(transformer)); + self + } + pub fn connect_timeout(mut self, timeout: Duration) -> Self { self.connect_timeout = timeout; self @@ -140,6 +162,7 @@ impl WebSocketProxyBuilderWithRequest { WebSocketProxy { upstream_request: self.upstream_request, control_message_matcher: self.control_message_matcher, + transform_first_message: self.transform_first_message, connect_timeout: self.connect_timeout, } } @@ -149,6 +172,7 @@ impl WebSocketProxyBuilderWithRequest { pub struct WebSocketProxy { upstream_request: ClientRequestBuilder, control_message_matcher: Option, + transform_first_message: Option, connect_timeout: Duration, } @@ -161,6 +185,7 @@ impl WebSocketProxy { let connection = WebSocketProxyConnection::new( self.upstream_request.clone(), self.control_message_matcher.clone(), + self.transform_first_message.clone(), self.connect_timeout, ); connection.run(client_socket).await; @@ -198,6 +223,7 @@ impl WebSocketProxy { Ok(PreconnectedProxy { upstream_stream: Some(upstream_stream), control_message_matcher: self.control_message_matcher.clone(), + transform_first_message: self.transform_first_message.clone(), }) } } @@ -239,6 +265,7 @@ impl Service> for WebSocketProxy { pub struct PreconnectedProxy { upstream_stream: Option>>, control_message_matcher: Option, + transform_first_message: Option, } impl PreconnectedProxy { @@ -255,6 +282,7 @@ impl PreconnectedProxy { client_socket, upstream_stream, self.control_message_matcher, + self.transform_first_message, ) .await; } @@ -270,6 +298,7 @@ impl PreconnectedProxy { struct WebSocketProxyConnection { upstream_request: ClientRequestBuilder, control_message_matcher: Option, + transform_first_message: Option, connect_timeout: Duration, } @@ -277,11 +306,13 @@ impl WebSocketProxyConnection { fn new( upstream_request: ClientRequestBuilder, control_message_matcher: Option, + transform_first_message: Option, connect_timeout: Duration, ) -> Self { Self { upstream_request, control_message_matcher, + transform_first_message, connect_timeout, } } @@ -336,6 +367,7 @@ impl WebSocketProxyConnection { let shutdown_rx2 = shutdown_tx.subscribe(); let control_matcher = self.control_message_matcher.clone(); + let first_msg_transformer = self.transform_first_message.clone(); let client_to_upstream = Self::run_client_to_upstream( client_receiver, @@ -343,6 +375,7 @@ impl WebSocketProxyConnection { shutdown_tx.clone(), shutdown_rx, control_matcher, + first_msg_transformer, pending_control_messages, pending_data_messages, pending_bytes, @@ -363,6 +396,7 @@ impl WebSocketProxyConnection { client_socket: WebSocket, upstream_stream: WebSocketStream>, control_message_matcher: Option, + transform_first_message: Option, ) { let (upstream_sender, upstream_receiver) = upstream_stream.split(); let (client_sender, client_receiver) = client_socket.split(); @@ -382,6 +416,7 @@ impl WebSocketProxyConnection { shutdown_tx.clone(), shutdown_rx, control_message_matcher, + transform_first_message, pending_control_messages, pending_data_messages, pending_bytes, @@ -404,10 +439,13 @@ impl WebSocketProxyConnection { shutdown_tx: tokio::sync::broadcast::Sender<(u16, String)>, mut shutdown_rx: tokio::sync::broadcast::Receiver<(u16, String)>, control_matcher: Option, + first_msg_transformer: Option, pending_control_messages: Arc>>, pending_data_messages: Arc>>, pending_bytes: Arc>, ) { + let mut has_transformed_first = first_msg_transformer.is_none(); + loop { tokio::select! { biased; @@ -440,6 +478,16 @@ impl WebSocketProxyConnection { match msg { Message::Text(text) => { + let text = if !has_transformed_first { + has_transformed_first = true; + if let Some(ref transformer) = first_msg_transformer { + transformer(text.to_string()) + } else { + text.to_string() + } + } else { + text.to_string() + }; let data = text.as_bytes().to_vec(); let size = Self::get_payload_size(&data); diff --git a/owhisper/owhisper-client/Cargo.toml b/owhisper/owhisper-client/Cargo.toml index 57a758e09a..38d00a4e23 100644 --- a/owhisper/owhisper-client/Cargo.toml +++ b/owhisper/owhisper-client/Cargo.toml @@ -10,6 +10,7 @@ hypr-language = { workspace = true } hypr-ws-client = { workspace = true } owhisper-interface = { workspace = true } +owhisper-providers = { workspace = true } futures-util = { workspace = true } reqwest = { workspace = true, features = ["json", "multipart"] } diff --git a/owhisper/owhisper-client/src/adapter/assemblyai/live.rs b/owhisper/owhisper-client/src/adapter/assemblyai/live.rs index 0cc841b6c0..304fb23fc2 100644 --- a/owhisper/owhisper-client/src/adapter/assemblyai/live.rs +++ b/owhisper/owhisper-client/src/adapter/assemblyai/live.rs @@ -62,7 +62,7 @@ impl RealtimeSttAdapter for AssemblyAIAdapter { } fn build_auth_header(&self, api_key: Option<&str>) -> Option<(&'static str, String)> { - api_key.map(|key| ("Authorization", key.to_string())) + api_key.and_then(|k| owhisper_providers::Provider::AssemblyAI.build_auth_header(k)) } fn keep_alive_message(&self) -> Option { diff --git a/owhisper/owhisper-client/src/adapter/assemblyai/mod.rs b/owhisper/owhisper-client/src/adapter/assemblyai/mod.rs index 3052f530c1..0efc37ef7d 100644 --- a/owhisper/owhisper-client/src/adapter/assemblyai/mod.rs +++ b/owhisper/owhisper-client/src/adapter/assemblyai/mod.rs @@ -8,19 +8,16 @@ impl AssemblyAIAdapter { pub fn is_supported_languages(_languages: &[hypr_language::Language]) -> bool { true } - - pub fn is_host(base_url: &str) -> bool { - super::host_matches(base_url, |h| h.contains("assemblyai.com")) - } } -const WS_PATH: &str = "/v3/ws"; - impl AssemblyAIAdapter { pub(crate) fn streaming_ws_url(api_base: &str) -> (url::Url, Vec<(String, String)>) { + use owhisper_providers::Provider; + if api_base.is_empty() { return ( - "wss://streaming.assemblyai.com/v3/ws" + Provider::AssemblyAI + .default_ws_url() .parse() .expect("invalid_default_ws_url"), Vec::new(), @@ -44,15 +41,19 @@ impl AssemblyAIAdapter { let existing_params = super::extract_query_params(&url); url.set_query(None); - super::append_path_if_missing(&mut url, WS_PATH); + super::append_path_if_missing(&mut url, Provider::AssemblyAI.ws_path()); super::set_scheme_from_host(&mut url); (url, existing_params) } pub(crate) fn batch_api_url(api_base: &str) -> url::Url { + use owhisper_providers::Provider; + if api_base.is_empty() { - return "https://api.assemblyai.com/v2" + return Provider::AssemblyAI + .default_api_url() + .unwrap() .parse() .expect("invalid_default_api_url"); } diff --git a/owhisper/owhisper-client/src/adapter/deepgram/live.rs b/owhisper/owhisper-client/src/adapter/deepgram/live.rs index ed51ebddc2..033ab83b73 100644 --- a/owhisper/owhisper-client/src/adapter/deepgram/live.rs +++ b/owhisper/owhisper-client/src/adapter/deepgram/live.rs @@ -29,7 +29,7 @@ impl RealtimeSttAdapter for DeepgramAdapter { } fn build_auth_header(&self, api_key: Option<&str>) -> Option<(&'static str, String)> { - api_key.map(|key| ("Authorization", format!("Token {}", key))) + api_key.and_then(|k| owhisper_providers::Provider::Deepgram.build_auth_header(k)) } fn keep_alive_message(&self) -> Option { diff --git a/owhisper/owhisper-client/src/adapter/fireworks/live.rs b/owhisper/owhisper-client/src/adapter/fireworks/live.rs index f6d5da119b..6dd4eef26a 100644 --- a/owhisper/owhisper-client/src/adapter/fireworks/live.rs +++ b/owhisper/owhisper-client/src/adapter/fireworks/live.rs @@ -40,7 +40,7 @@ impl RealtimeSttAdapter for FireworksAdapter { } fn build_auth_header(&self, api_key: Option<&str>) -> Option<(&'static str, String)> { - api_key.map(|key| ("Authorization", key.to_string())) + api_key.and_then(|k| owhisper_providers::Provider::Fireworks.build_auth_header(k)) } fn keep_alive_message(&self) -> Option { diff --git a/owhisper/owhisper-client/src/adapter/fireworks/mod.rs b/owhisper/owhisper-client/src/adapter/fireworks/mod.rs index 97a6b8f814..7bcaf45389 100644 --- a/owhisper/owhisper-client/src/adapter/fireworks/mod.rs +++ b/owhisper/owhisper-client/src/adapter/fireworks/mod.rs @@ -1,8 +1,7 @@ mod batch; mod live; -pub(crate) const DEFAULT_API_HOST: &str = "api.fireworks.ai"; -const WS_PATH: &str = "/v1/audio/transcriptions/streaming"; +use owhisper_providers::Provider; #[derive(Clone, Default)] pub struct FireworksAdapter; @@ -12,24 +11,18 @@ impl FireworksAdapter { true } - pub fn is_host(base_url: &str) -> bool { - super::host_matches(base_url, Self::is_fireworks_host) - } - pub(crate) fn api_host(api_base: &str) -> String { if api_base.is_empty() { - return DEFAULT_API_HOST.to_string(); + return Provider::Fireworks.default_api_host().to_string(); } let url: url::Url = match api_base.parse() { Ok(u) => u, - Err(_) => return DEFAULT_API_HOST.to_string(), + Err(_) => return Provider::Fireworks.default_api_host().to_string(), }; - url.host_str().unwrap_or(DEFAULT_API_HOST).to_string() - } - - pub(crate) fn is_fireworks_host(host: &str) -> bool { - host.contains("fireworks.ai") + url.host_str() + .unwrap_or(Provider::Fireworks.default_api_host()) + .to_string() } pub(crate) fn batch_api_host(api_base: &str) -> String { @@ -45,7 +38,8 @@ impl FireworksAdapter { pub(crate) fn build_ws_url_from_base(api_base: &str) -> (url::Url, Vec<(String, String)>) { let default_url = || { ( - format!("wss://audio-streaming-v2.{}{}", DEFAULT_API_HOST, WS_PATH) + Provider::Fireworks + .default_ws_url() .parse() .expect("invalid_default_ws_url"), Vec::new(), @@ -67,9 +61,13 @@ impl FireworksAdapter { let existing_params = super::extract_query_params(&parsed); - let url: url::Url = format!("wss://{}{}", Self::ws_host(api_base), WS_PATH) - .parse() - .expect("invalid_ws_url"); + let url: url::Url = format!( + "wss://{}{}", + Self::ws_host(api_base), + Provider::Fireworks.ws_path() + ) + .parse() + .expect("invalid_ws_url"); (url, existing_params) } } diff --git a/owhisper/owhisper-client/src/adapter/gladia/mod.rs b/owhisper/owhisper-client/src/adapter/gladia/mod.rs index 4ef15a611a..c84e5d659a 100644 --- a/owhisper/owhisper-client/src/adapter/gladia/mod.rs +++ b/owhisper/owhisper-client/src/adapter/gladia/mod.rs @@ -1,9 +1,7 @@ mod batch; mod live; -pub(crate) const DEFAULT_API_HOST: &str = "api.gladia.io"; -pub(crate) const WS_PATH: &str = "/v2/live"; -const API_BASE: &str = "https://api.gladia.io/v2"; +use owhisper_providers::Provider; #[derive(Clone, Default)] pub struct GladiaAdapter; @@ -13,14 +11,6 @@ impl GladiaAdapter { true } - pub fn is_host(base_url: &str) -> bool { - super::host_matches(base_url, Self::is_gladia_host) - } - - pub(crate) fn is_gladia_host(host: &str) -> bool { - host.contains("gladia.io") - } - pub(crate) fn build_ws_url_from_base(api_base: &str) -> (url::Url, Vec<(String, String)>) { if api_base.is_empty() { return (Self::default_ws_url(), Vec::new()); @@ -32,7 +22,7 @@ impl GladiaAdapter { let parsed: url::Url = api_base.parse().expect("invalid_api_base"); let existing_params = super::extract_query_params(&parsed); - let url = Self::build_url_with_scheme(&parsed, WS_PATH, true); + let url = Self::build_url_with_scheme(&parsed, Provider::Gladia.ws_path(), true); (url, existing_params) } @@ -42,11 +32,13 @@ impl GladiaAdapter { } let parsed: url::Url = api_base.parse().expect("invalid_api_base"); - Self::build_url_with_scheme(&parsed, WS_PATH, false) + Self::build_url_with_scheme(&parsed, Provider::Gladia.ws_path(), false) } fn build_url_with_scheme(parsed: &url::Url, path: &str, use_ws: bool) -> url::Url { - let host = parsed.host_str().unwrap_or(DEFAULT_API_HOST); + let host = parsed + .host_str() + .unwrap_or(Provider::Gladia.default_api_host()); let is_local = super::is_local_host(host); let scheme = match (use_ws, is_local) { (true, true) => "ws", @@ -64,20 +56,29 @@ impl GladiaAdapter { } fn default_ws_url() -> url::Url { - format!("wss://{}{}", DEFAULT_API_HOST, WS_PATH) + Provider::Gladia + .default_ws_url() .parse() .expect("invalid_default_ws_url") } fn default_http_url() -> url::Url { - format!("https://{}{}", DEFAULT_API_HOST, WS_PATH) - .parse() - .expect("invalid_default_http_url") + format!( + "https://{}{}", + Provider::Gladia.default_api_host(), + Provider::Gladia.ws_path() + ) + .parse() + .expect("invalid_default_http_url") } pub(crate) fn batch_api_url(api_base: &str) -> url::Url { if api_base.is_empty() { - return API_BASE.parse().expect("invalid_default_api_url"); + return Provider::Gladia + .default_api_url() + .unwrap() + .parse() + .expect("invalid_default_api_url"); } api_base.parse().expect("invalid_api_base") @@ -131,10 +132,10 @@ mod tests { #[test] fn test_is_host() { - assert!(GladiaAdapter::is_host("https://api.gladia.io")); - assert!(GladiaAdapter::is_host("https://api.gladia.io/v2")); - assert!(!GladiaAdapter::is_host("https://api.deepgram.com")); - assert!(!GladiaAdapter::is_host("https://api.assemblyai.com")); + assert!(Provider::Gladia.matches_url("https://api.gladia.io")); + assert!(Provider::Gladia.matches_url("https://api.gladia.io/v2")); + assert!(!Provider::Gladia.matches_url("https://api.deepgram.com")); + assert!(!Provider::Gladia.matches_url("https://api.assemblyai.com")); } #[test] diff --git a/owhisper/owhisper-client/src/adapter/mod.rs b/owhisper/owhisper-client/src/adapter/mod.rs index 023f302ea5..bf634fc37e 100644 --- a/owhisper/owhisper-client/src/adapter/mod.rs +++ b/owhisper/owhisper-client/src/adapter/mod.rs @@ -172,28 +172,41 @@ pub enum AdapterKind { Deepgram, AssemblyAI, OpenAI, + Gladia, } impl AdapterKind { pub fn from_url_and_languages(base_url: &str, languages: &[hypr_language::Language]) -> Self { + use owhisper_providers::Provider; + if is_hyprnote_cloud_host(base_url) { if DeepgramAdapter::is_supported_languages(languages) { - Self::Deepgram + return Self::Deepgram; } else { - Self::Soniox + return Self::Soniox; } - } else if is_local_stt_host(base_url) { - Self::Argmax - } else if AssemblyAIAdapter::is_host(base_url) { - Self::AssemblyAI - } else if SonioxAdapter::is_host(base_url) { - Self::Soniox - } else if FireworksAdapter::is_host(base_url) { - Self::Fireworks - } else if OpenAIAdapter::is_host(base_url) { - Self::OpenAI - } else { - Self::Deepgram + } + + if is_local_stt_host(base_url) { + return Self::Argmax; + } + + Provider::from_url(base_url) + .map(Self::from) + .unwrap_or(Self::Deepgram) + } +} + +impl From for AdapterKind { + fn from(p: owhisper_providers::Provider) -> Self { + use owhisper_providers::Provider; + match p { + Provider::Deepgram => Self::Deepgram, + Provider::AssemblyAI => Self::AssemblyAI, + Provider::Soniox => Self::Soniox, + Provider::Fireworks => Self::Fireworks, + Provider::OpenAI => Self::OpenAI, + Provider::Gladia => Self::Gladia, } } } diff --git a/owhisper/owhisper-client/src/adapter/openai/live.rs b/owhisper/owhisper-client/src/adapter/openai/live.rs index 2ef9041d77..884a07cf3d 100644 --- a/owhisper/owhisper-client/src/adapter/openai/live.rs +++ b/owhisper/owhisper-client/src/adapter/openai/live.rs @@ -36,7 +36,7 @@ impl RealtimeSttAdapter for OpenAIAdapter { } fn build_auth_header(&self, api_key: Option<&str>) -> Option<(&'static str, String)> { - api_key.map(|key| ("Authorization", format!("Bearer {}", key))) + api_key.and_then(|k| owhisper_providers::Provider::OpenAI.build_auth_header(k)) } fn keep_alive_message(&self) -> Option { diff --git a/owhisper/owhisper-client/src/adapter/openai/mod.rs b/owhisper/owhisper-client/src/adapter/openai/mod.rs index c6002b2baf..c3fedc3e16 100644 --- a/owhisper/owhisper-client/src/adapter/openai/mod.rs +++ b/owhisper/owhisper-client/src/adapter/openai/mod.rs @@ -1,14 +1,8 @@ mod batch; mod live; -pub(crate) const DEFAULT_WS_HOST: &str = "api.openai.com"; -pub(crate) const WS_PATH: &str = "/v1/realtime"; - -// OpenAI STT Models: -// - whisper-1: Legacy model, supports verbose_json with word timestamps (batch only) -// - gpt-4o-transcribe: High quality, supports both batch (json only) and realtime -// - gpt-4o-mini-transcribe: Cost-efficient, supports both batch (json only) and realtime -// - gpt-4o-transcribe-diarize: Speaker diarization (batch only, not yet supported here) +use owhisper_providers::Provider; + pub(crate) const DEFAULT_TRANSCRIPTION_MODEL: &str = "gpt-4o-transcribe"; #[derive(Clone, Default)] @@ -19,18 +13,11 @@ impl OpenAIAdapter { true } - pub fn is_host(base_url: &str) -> bool { - super::host_matches(base_url, Self::is_openai_host) - } - - pub(crate) fn is_openai_host(host: &str) -> bool { - host.contains("openai.com") - } - pub(crate) fn build_ws_url_from_base(api_base: &str) -> (url::Url, Vec<(String, String)>) { if api_base.is_empty() { return ( - format!("wss://{}{}", DEFAULT_WS_HOST, WS_PATH) + Provider::OpenAI + .default_ws_url() .parse() .expect("invalid_default_ws_url"), vec![("intent".to_string(), "transcription".to_string())], @@ -48,8 +35,10 @@ impl OpenAIAdapter { existing_params.push(("intent".to_string(), "transcription".to_string())); } - let host = parsed.host_str().unwrap_or(DEFAULT_WS_HOST); - let mut url: url::Url = format!("wss://{}{}", host, WS_PATH) + let host = parsed + .host_str() + .unwrap_or(Provider::OpenAI.default_ws_host()); + let mut url: url::Url = format!("wss://{}{}", host, Provider::OpenAI.ws_path()) .parse() .expect("invalid_ws_url"); @@ -91,8 +80,8 @@ mod tests { #[test] fn test_is_openai_host() { - assert!(OpenAIAdapter::is_openai_host("api.openai.com")); - assert!(OpenAIAdapter::is_openai_host("openai.com")); - assert!(!OpenAIAdapter::is_openai_host("api.deepgram.com")); + assert!(Provider::OpenAI.is_host("api.openai.com")); + assert!(Provider::OpenAI.is_host("openai.com")); + assert!(!Provider::OpenAI.is_host("api.deepgram.com")); } } diff --git a/owhisper/owhisper-client/src/adapter/soniox/mod.rs b/owhisper/owhisper-client/src/adapter/soniox/mod.rs index 0d48e82ccd..7b4450911c 100644 --- a/owhisper/owhisper-client/src/adapter/soniox/mod.rs +++ b/owhisper/owhisper-client/src/adapter/soniox/mod.rs @@ -1,9 +1,6 @@ mod batch; mod live; -pub(crate) const DEFAULT_API_HOST: &str = "api.soniox.com"; -pub(crate) const DEFAULT_WS_HOST: &str = "stt-rt.soniox.com"; - #[derive(Clone, Default)] pub struct SonioxAdapter; @@ -12,39 +9,38 @@ impl SonioxAdapter { true } - pub fn is_host(base_url: &str) -> bool { - super::host_matches(base_url, Self::is_soniox_host) - } - pub(crate) fn api_host(api_base: &str) -> String { + use owhisper_providers::Provider; + + let default_host = Provider::Soniox.default_api_host(); + if api_base.is_empty() { - return DEFAULT_API_HOST.to_string(); + return default_host.to_string(); } let url: url::Url = api_base.parse().expect("invalid_api_base"); - url.host_str().unwrap_or(DEFAULT_API_HOST).to_string() - } - - pub(crate) fn is_soniox_host(host: &str) -> bool { - host.contains("soniox.com") + url.host_str().unwrap_or(default_host).to_string() } pub(crate) fn ws_host(api_base: &str) -> String { + use owhisper_providers::Provider; + let api_host = Self::api_host(api_base); if let Some(rest) = api_host.strip_prefix("api.") { format!("stt-rt.{}", rest) } else { - DEFAULT_WS_HOST.to_string() + Provider::Soniox.default_ws_host().to_string() } } pub(crate) fn build_ws_url_from_base(api_base: &str) -> (url::Url, Vec<(String, String)>) { - const WS_PATH: &str = "/transcribe-websocket"; + use owhisper_providers::Provider; if api_base.is_empty() { return ( - format!("wss://{}{}", DEFAULT_WS_HOST, WS_PATH) + Provider::Soniox + .default_ws_url() .parse() .expect("invalid_default_ws_url"), Vec::new(), @@ -58,9 +54,13 @@ impl SonioxAdapter { let parsed: url::Url = api_base.parse().expect("invalid_api_base"); let existing_params = super::extract_query_params(&parsed); - let url: url::Url = format!("wss://{}{}", Self::ws_host(api_base), WS_PATH) - .parse() - .expect("invalid_ws_url"); + let url: url::Url = format!( + "wss://{}{}", + Self::ws_host(api_base), + Provider::Soniox.ws_path() + ) + .parse() + .expect("invalid_ws_url"); (url, existing_params) } } diff --git a/owhisper/owhisper-providers/Cargo.toml b/owhisper/owhisper-providers/Cargo.toml new file mode 100644 index 0000000000..1e00c5314e --- /dev/null +++ b/owhisper/owhisper-providers/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "owhisper-providers" +version = "0.0.1" +edition = "2024" + +[dependencies] +serde_json = { workspace = true } +strum = { workspace = true, features = ["derive"] } +url = { workspace = true } diff --git a/owhisper/owhisper-providers/src/lib.rs b/owhisper/owhisper-providers/src/lib.rs new file mode 100644 index 0000000000..724b9806e6 --- /dev/null +++ b/owhisper/owhisper-providers/src/lib.rs @@ -0,0 +1,209 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Auth { + Header { + name: &'static str, + prefix: Option<&'static str>, + }, + FirstMessage { + field_name: &'static str, + }, + SessionInit { + header_name: &'static str, + }, +} + +impl Auth { + pub fn build_header(&self, api_key: &str) -> Option<(&'static str, String)> { + match self { + Auth::Header { name, prefix } => { + let value = match prefix { + Some(p) => format!("{}{}", p, api_key), + None => api_key.to_string(), + }; + Some((name, value)) + } + Auth::FirstMessage { .. } | Auth::SessionInit { .. } => None, + } + } + + pub fn build_session_init_header(&self, api_key: &str) -> Option<(&'static str, String)> { + match self { + Auth::SessionInit { header_name } => Some((header_name, api_key.to_string())), + _ => None, + } + } + + pub fn transform_first_message(&self, payload: String, api_key: &str) -> String { + match self { + Auth::FirstMessage { field_name } => { + match serde_json::from_str::(&payload) { + Ok(mut json) => { + if let Some(obj) = json.as_object_mut() { + obj.insert( + (*field_name).to_string(), + serde_json::Value::String(api_key.to_string()), + ); + } + serde_json::to_string(&json).unwrap_or(payload) + } + Err(_) => payload, + } + } + Auth::Header { .. } | Auth::SessionInit { .. } => payload, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::EnumString, strum::Display)] +#[strum(serialize_all = "lowercase")] +pub enum Provider { + Deepgram, + #[strum(serialize = "assemblyai")] + AssemblyAI, + Soniox, + Fireworks, + #[strum(serialize = "openai")] + OpenAI, + Gladia, +} + +impl Provider { + const ALL: [Provider; 6] = [ + Self::Deepgram, + Self::AssemblyAI, + Self::Soniox, + Self::Fireworks, + Self::OpenAI, + Self::Gladia, + ]; + + pub fn from_host(host: &str) -> Option { + Self::ALL.into_iter().find(|p| p.is_host(host)) + } + + pub fn auth(&self) -> Auth { + match self { + Self::Deepgram => Auth::Header { + name: "Authorization", + prefix: Some("Token "), + }, + Self::AssemblyAI => Auth::Header { + name: "Authorization", + prefix: None, + }, + Self::Fireworks => Auth::Header { + name: "Authorization", + prefix: None, + }, + Self::OpenAI => Auth::Header { + name: "Authorization", + prefix: Some("Bearer "), + }, + Self::Gladia => Auth::SessionInit { + header_name: "x-gladia-key", + }, + Self::Soniox => Auth::FirstMessage { + field_name: "api_key", + }, + } + } + + pub fn build_auth_header(&self, api_key: &str) -> Option<(&'static str, String)> { + self.auth().build_header(api_key) + } + + pub fn default_ws_url(&self) -> String { + format!("wss://{}{}", self.default_ws_host(), self.ws_path()) + } + + pub fn default_api_host(&self) -> &'static str { + match self { + Self::Deepgram => "api.deepgram.com", + Self::AssemblyAI => "api.assemblyai.com", + Self::Soniox => "api.soniox.com", + Self::Fireworks => "api.fireworks.ai", + Self::OpenAI => "api.openai.com", + Self::Gladia => "api.gladia.io", + } + } + + pub fn default_ws_host(&self) -> &'static str { + match self { + Self::Deepgram => "api.deepgram.com", + Self::AssemblyAI => "streaming.assemblyai.com", + Self::Soniox => "stt-rt.soniox.com", + Self::Fireworks => "audio-streaming-v2.api.fireworks.ai", + Self::OpenAI => "api.openai.com", + Self::Gladia => "api.gladia.io", + } + } + + pub fn ws_path(&self) -> &'static str { + match self { + Self::Deepgram => "/v1/listen", + Self::AssemblyAI => "/v3/ws", + Self::Soniox => "/transcribe-websocket", + Self::Fireworks => "/v1/audio/transcriptions/streaming", + Self::OpenAI => "/v1/realtime", + Self::Gladia => "/v2/live", + } + } + + pub fn default_api_url(&self) -> Option<&'static str> { + match self { + Self::Deepgram => None, + Self::AssemblyAI => Some("https://api.assemblyai.com/v2"), + Self::Soniox => None, + Self::Fireworks => None, + Self::OpenAI => None, + Self::Gladia => Some("https://api.gladia.io/v2"), + } + } + + pub fn domain(&self) -> &'static str { + match self { + Self::Deepgram => "deepgram.com", + Self::AssemblyAI => "assemblyai.com", + Self::Soniox => "soniox.com", + Self::Fireworks => "fireworks.ai", + Self::OpenAI => "openai.com", + Self::Gladia => "gladia.io", + } + } + + pub fn is_host(&self, host: &str) -> bool { + let domain = self.domain(); + host == domain || host.ends_with(&format!(".{}", domain)) + } + + pub fn matches_url(&self, base_url: &str) -> bool { + url::Url::parse(base_url) + .ok() + .and_then(|u| u.host_str().map(|h| self.is_host(h))) + .unwrap_or(false) + } + + pub fn from_url(base_url: &str) -> Option { + url::Url::parse(base_url) + .ok() + .and_then(|u| u.host_str().and_then(Self::from_host)) + } + + pub fn env_key_name(&self) -> &'static str { + match self { + Self::Deepgram => "DEEPGRAM_API_KEY", + Self::AssemblyAI => "ASSEMBLYAI_API_KEY", + Self::Soniox => "SONIOX_API_KEY", + Self::Fireworks => "FIREWORKS_API_KEY", + Self::OpenAI => "OPENAI_API_KEY", + Self::Gladia => "GLADIA_API_KEY", + } + } + + pub fn default_query_params(&self) -> &'static [(&'static str, &'static str)] { + match self { + Self::Deepgram => &[("model", "nova-3-general"), ("mip_opt_out", "false")], + _ => &[], + } + } +} diff --git a/plugins/listener/src/actors/listener.rs b/plugins/listener/src/actors/listener.rs index 44962107b7..14a897b1d2 100644 --- a/plugins/listener/src/actors/listener.rs +++ b/plugins/listener/src/actors/listener.rs @@ -6,7 +6,7 @@ use tokio::time::error::Elapsed; use owhisper_client::{ AdapterKind, ArgmaxAdapter, AssemblyAIAdapter, DeepgramAdapter, FinalizeHandle, - FireworksAdapter, OpenAIAdapter, RealtimeSttAdapter, SonioxAdapter, + FireworksAdapter, GladiaAdapter, OpenAIAdapter, RealtimeSttAdapter, SonioxAdapter, }; use owhisper_interface::stream::{Extra, StreamResponse}; use owhisper_interface::{ControlMessage, MixedMessage}; @@ -240,6 +240,12 @@ async fn spawn_rx_task( (AdapterKind::OpenAI, true) => { spawn_rx_task_dual_with_adapter::(args, myself).await } + (AdapterKind::Gladia, false) => { + spawn_rx_task_single_with_adapter::(args, myself).await + } + (AdapterKind::Gladia, true) => { + spawn_rx_task_dual_with_adapter::(args, myself).await + } } } diff --git a/plugins/listener2/src/batch.rs b/plugins/listener2/src/batch.rs index db7c61cace..63ff8d25a1 100644 --- a/plugins/listener2/src/batch.rs +++ b/plugins/listener2/src/batch.rs @@ -5,7 +5,7 @@ use std::time::Duration; use futures_util::StreamExt; use owhisper_client::{ AdapterKind, ArgmaxAdapter, AssemblyAIAdapter, DeepgramAdapter, FireworksAdapter, - OpenAIAdapter, RealtimeSttAdapter, SonioxAdapter, + GladiaAdapter, OpenAIAdapter, RealtimeSttAdapter, SonioxAdapter, }; use owhisper_interface::stream::StreamResponse; use owhisper_interface::{ControlMessage, MixedMessage}; @@ -218,6 +218,7 @@ async fn spawn_batch_task( spawn_batch_task_with_adapter::(args, myself).await } AdapterKind::OpenAI => spawn_batch_task_with_adapter::(args, myself).await, + AdapterKind::Gladia => spawn_batch_task_with_adapter::(args, myself).await, } }