Skip to content

Commit fcd12f5

Browse files
authored
Add owhisper-providers and share it in client adapters and proxy server (#2396)
1 parent 01a236a commit fcd12f5

File tree

23 files changed

+1108
-141
lines changed

23 files changed

+1108
-141
lines changed

Cargo.lock

Lines changed: 325 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ owhisper-client = { path = "owhisper/owhisper-client", package = "owhisper-clien
8989
owhisper-config = { path = "owhisper/owhisper-config", package = "owhisper-config" }
9090
owhisper-interface = { path = "owhisper/owhisper-interface", package = "owhisper-interface" }
9191
owhisper-model = { path = "owhisper/owhisper-model", package = "owhisper-model" }
92+
owhisper-providers = { path = "owhisper/owhisper-providers", package = "owhisper-providers" }
9293

9394
tauri = "2.9"
9495
tauri-build = "2.5"
@@ -165,7 +166,7 @@ clap = "4"
165166
codes-iso-639 = "0.1.5"
166167
derive_more = "2"
167168
dirs = "6.0.0"
168-
dotenv = "0.15.0"
169+
dotenvy = "0.15.7"
169170
include_url_macro = "0.1.0"
170171
indoc = "2"
171172
itertools = "0.14.0"
@@ -205,6 +206,7 @@ async-openai = { git = "https://github.com/fastrepl/async-openai", rev = "6404d3
205206
async-stripe = { version = "0.39.1", default-features = false }
206207
gbnf-validator = { git = "https://github.com/fastrepl/gbnf-validator", rev = "3dec055" }
207208

209+
jsonwebtoken = { version = "10", features = ["rust_crypto"] }
208210
sentry = "0.42"
209211
vergen-gix = "1"
210212

apps/stt/Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,19 @@ path = "src/main.rs"
99

1010
[dependencies]
1111
hypr-transcribe-proxy = { workspace = true }
12+
owhisper-providers = { workspace = true }
1213

1314
axum = { workspace = true, features = ["ws"] }
1415
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] }
1516
tower-http = { workspace = true, features = ["trace"] }
1617
tracing = { workspace = true }
1718
tracing-subscriber = { workspace = true, features = ["env-filter"] }
1819

20+
reqwest = { workspace = true, features = ["json"] }
21+
serde = { workspace = true, features = ["derive"] }
22+
serde_json = { workspace = true }
23+
url = { workspace = true }
24+
25+
dotenvy = { workspace = true }
26+
jsonwebtoken = { workspace = true }
1927
sentry = { workspace = true }

apps/stt/src/auth.rs

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
use std::sync::Arc;
2+
use std::time::{Duration, Instant};
3+
4+
use axum::{
5+
extract::FromRequestParts,
6+
http::{StatusCode, request::Parts},
7+
};
8+
use jsonwebtoken::{DecodingKey, Validation, decode, decode_header, jwk::JwkSet};
9+
use serde::Deserialize;
10+
use tokio::sync::RwLock;
11+
12+
use crate::env::env;
13+
14+
const ENTITLEMENT_PRO: &str = "hyprnote_pro";
15+
const JWKS_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
16+
17+
#[derive(Debug, Clone)]
18+
pub struct AuthUser {
19+
pub user_id: String,
20+
pub entitlements: Vec<String>,
21+
}
22+
23+
impl AuthUser {
24+
pub fn is_pro(&self) -> bool {
25+
self.entitlements.iter().any(|e| e == ENTITLEMENT_PRO)
26+
}
27+
}
28+
29+
#[derive(Debug, Deserialize)]
30+
struct Claims {
31+
sub: String,
32+
#[serde(default)]
33+
entitlements: Vec<String>,
34+
}
35+
36+
struct CachedJwks {
37+
jwks: JwkSet,
38+
fetched_at: Instant,
39+
}
40+
41+
static JWKS_CACHE: std::sync::OnceLock<Arc<RwLock<Option<CachedJwks>>>> =
42+
std::sync::OnceLock::new();
43+
44+
fn jwks_cache() -> &'static Arc<RwLock<Option<CachedJwks>>> {
45+
JWKS_CACHE.get_or_init(|| Arc::new(RwLock::new(None)))
46+
}
47+
48+
async fn get_jwks() -> Result<JwkSet, &'static str> {
49+
let cache = jwks_cache();
50+
51+
{
52+
let guard = cache.read().await;
53+
if let Some(cached) = guard.as_ref() {
54+
if cached.fetched_at.elapsed() < JWKS_CACHE_TTL {
55+
return Ok(cached.jwks.clone());
56+
}
57+
}
58+
}
59+
60+
let env = env();
61+
let jwks_url = format!("{}/auth/v1/.well-known/jwks.json", env.supabase_url);
62+
63+
let jwks: JwkSet = reqwest::get(&jwks_url)
64+
.await
65+
.map_err(|_| "failed to fetch jwks")?
66+
.json()
67+
.await
68+
.map_err(|_| "failed to parse jwks")?;
69+
70+
{
71+
let mut guard = cache.write().await;
72+
*guard = Some(CachedJwks {
73+
jwks: jwks.clone(),
74+
fetched_at: Instant::now(),
75+
});
76+
}
77+
78+
Ok(jwks)
79+
}
80+
81+
impl<S> FromRequestParts<S> for AuthUser
82+
where
83+
S: Send + Sync,
84+
{
85+
type Rejection = (StatusCode, &'static str);
86+
87+
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
88+
let auth_header = parts
89+
.headers
90+
.get("authorization")
91+
.and_then(|v| v.to_str().ok())
92+
.ok_or((StatusCode::UNAUTHORIZED, "missing authorization header"))?;
93+
94+
let token = auth_header
95+
.strip_prefix("Bearer ")
96+
.or_else(|| auth_header.strip_prefix("bearer "))
97+
.ok_or((StatusCode::UNAUTHORIZED, "invalid authorization header"))?;
98+
99+
let header =
100+
decode_header(token).map_err(|_| (StatusCode::UNAUTHORIZED, "invalid token header"))?;
101+
102+
let kid = header
103+
.kid
104+
.as_ref()
105+
.ok_or((StatusCode::UNAUTHORIZED, "missing kid in token"))?;
106+
107+
let jwks = get_jwks()
108+
.await
109+
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
110+
111+
let jwk = jwks
112+
.find(kid)
113+
.ok_or((StatusCode::UNAUTHORIZED, "unknown signing key"))?;
114+
115+
let key = DecodingKey::from_jwk(jwk)
116+
.map_err(|_| (StatusCode::UNAUTHORIZED, "invalid signing key"))?;
117+
118+
let alg = jwk
119+
.common
120+
.key_algorithm
121+
.and_then(|a| a.to_string().parse().ok())
122+
.ok_or((StatusCode::UNAUTHORIZED, "unsupported algorithm"))?;
123+
124+
let mut validation = Validation::new(alg);
125+
validation.set_audience(&["authenticated"]);
126+
validation.validate_exp = true;
127+
128+
let token_data = decode::<Claims>(token, &key, &validation)
129+
.map_err(|_| (StatusCode::UNAUTHORIZED, "invalid token"))?;
130+
131+
Ok(AuthUser {
132+
user_id: token_data.claims.sub,
133+
entitlements: token_data.claims.entitlements,
134+
})
135+
}
136+
}

apps/stt/src/env.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
use std::collections::HashMap;
2+
use std::sync::OnceLock;
3+
4+
use owhisper_providers::Provider;
5+
6+
pub struct Env {
7+
pub port: u16,
8+
pub sentry_dsn: Option<String>,
9+
pub supabase_url: String,
10+
api_keys: HashMap<Provider, String>,
11+
}
12+
13+
static ENV: OnceLock<Env> = OnceLock::new();
14+
15+
pub fn env() -> &'static Env {
16+
ENV.get_or_init(|| {
17+
let _ = dotenvy::dotenv();
18+
Env::from_env()
19+
})
20+
}
21+
22+
impl Env {
23+
fn from_env() -> Self {
24+
let providers = [
25+
Provider::Deepgram,
26+
Provider::AssemblyAI,
27+
Provider::Soniox,
28+
Provider::Fireworks,
29+
Provider::OpenAI,
30+
Provider::Gladia,
31+
];
32+
let api_keys = providers
33+
.into_iter()
34+
.map(|p| (p, required(p.env_key_name())))
35+
.collect();
36+
37+
Self {
38+
port: parse_or("PORT", 3000),
39+
sentry_dsn: optional("SENTRY_DSN"),
40+
supabase_url: required("SUPABASE_URL"),
41+
api_keys,
42+
}
43+
}
44+
45+
pub fn api_key_for(&self, provider: Provider) -> String {
46+
self.api_keys
47+
.get(&provider)
48+
.cloned()
49+
.unwrap_or_else(|| panic!("{} is not configured", provider.env_key_name()))
50+
}
51+
}
52+
53+
fn required(key: &str) -> String {
54+
std::env::var(key).unwrap_or_else(|_| panic!("{key} is required"))
55+
}
56+
57+
fn optional(key: &str) -> Option<String> {
58+
std::env::var(key).ok().filter(|s| !s.is_empty())
59+
}
60+
61+
fn parse_or<T: std::str::FromStr>(key: &str, default: T) -> T {
62+
std::env::var(key)
63+
.ok()
64+
.and_then(|v| v.parse().ok())
65+
.unwrap_or(default)
66+
}

0 commit comments

Comments
 (0)