Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 8 additions & 23 deletions owhisper/owhisper-client/src/adapter/gladia/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ struct FileInfo {
audio_duration: Option<f64>,
}

#[derive(Debug, Deserialize)]
#[derive(Debug, Default, Deserialize)]
struct TranscriptResult {
#[serde(default)]
metadata: Option<ResultMetadata>,
Expand All @@ -87,7 +87,7 @@ struct ResultMetadata {
audio_duration: Option<f64>,
}

#[derive(Debug, Deserialize)]
#[derive(Debug, Default, Deserialize)]
struct Transcription {
#[serde(default)]
full_transcript: Option<String>,
Expand Down Expand Up @@ -185,18 +185,10 @@ impl GladiaAdapter {
.map(|l| l.iso639().code().to_string())
.collect();

let language_config = if languages.is_empty() {
None
} else {
Some(LanguageConfig {
languages,
code_switching: if params.languages.len() > 1 {
Some(true)
} else {
None
},
})
};
let language_config = (!languages.is_empty()).then(|| LanguageConfig {
languages,
code_switching: (params.languages.len() > 1).then_some(true),
});

let transcript_request = TranscriptRequest {
audio_url: upload_result.audio_url,
Expand Down Expand Up @@ -270,15 +262,8 @@ impl GladiaAdapter {
}

fn convert_to_batch_response(response: TranscriptResponse) -> BatchResponse {
let result = response.result.unwrap_or(TranscriptResult {
metadata: None,
transcription: None,
});

let transcription = result.transcription.unwrap_or(Transcription {
full_transcript: None,
utterances: Vec::new(),
});
let result = response.result.unwrap_or_default();
let transcription = result.transcription.unwrap_or_default();

let words: Vec<BatchWord> = transcription
.utterances
Expand Down
134 changes: 59 additions & 75 deletions owhisper/owhisper-client/src/adapter/gladia/live.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,37 @@ use super::GladiaAdapter;
use crate::adapter::parsing::WordBuilder;
use crate::adapter::RealtimeSttAdapter;

fn session_channels() -> &'static Mutex<HashMap<String, u8>> {
static SESSION_CHANNELS: OnceLock<Mutex<HashMap<String, u8>>> = OnceLock::new();
SESSION_CHANNELS.get_or_init(|| Mutex::new(HashMap::new()))
struct SessionChannels;

impl SessionChannels {
fn store() -> &'static Mutex<HashMap<String, u8>> {
static SESSION_CHANNELS: OnceLock<Mutex<HashMap<String, u8>>> = OnceLock::new();
SESSION_CHANNELS.get_or_init(|| Mutex::new(HashMap::new()))
}

fn insert(session_id: String, channels: u8) {
if let Ok(mut map) = Self::store().lock() {
map.insert(session_id, channels);
}
}

fn get(session_id: &str) -> Option<u8> {
Self::store()
.lock()
.ok()
.and_then(|map| map.get(session_id).copied())
}

fn remove(session_id: &str) -> Option<u8> {
Self::store()
.lock()
.ok()
.and_then(|mut map| map.remove(session_id))
}

fn get_or_infer(session_id: &str, channel_idx: i32) -> u8 {
Self::get(session_id).unwrap_or_else(|| (channel_idx + 1).max(1) as u8)
}
}

impl RealtimeSttAdapter for GladiaAdapter {
Expand Down Expand Up @@ -56,41 +84,17 @@ impl RealtimeSttAdapter for GladiaAdapter {
}

let key = api_key?;
let post_url = Self::build_http_url(api_base);

let post_url = if api_base.is_empty() {
"https://api.gladia.io/v2/live".to_string()
} else {
let parsed: url::Url = api_base.parse().ok()?;
let host = parsed.host_str().unwrap_or("api.gladia.io");
let scheme = if crate::adapter::is_local_host(host) {
"http"
} else {
"https"
};
let host_with_port = match parsed.port() {
Some(port) => format!("{host}:{port}"),
None => host.to_string(),
};
format!("{scheme}://{host_with_port}/v2/live")
};

let language_config = if params.languages.is_empty() {
None
} else {
Some(LanguageConfig {
languages: params
.languages
.iter()
.map(|l| l.iso639().code().to_string())
.collect(),
})
};
let language_config = (!params.languages.is_empty()).then(|| LanguageConfig {
languages: params
.languages
.iter()
.map(|l| l.iso639().code().to_string())
.collect(),
});

let custom_vocabulary = if params.keywords.is_empty() {
None
} else {
Some(params.keywords.clone())
};
let custom_vocabulary = (!params.keywords.is_empty()).then(|| params.keywords.clone());

let body = GladiaConfig {
encoding: "wav/pcm",
Expand All @@ -104,39 +108,30 @@ impl RealtimeSttAdapter for GladiaAdapter {
}),
};

let body_json = match serde_json::to_value(&body) {
Ok(v) => v,
Err(e) => {
let body_json = serde_json::to_value(&body)
.map_err(|e| {
tracing::error!(error = ?e, "gladia_init_serialize_failed");
return None;
}
};
})
.ok()?;

let resp = match ureq::post(&post_url)
let resp = ureq::post(post_url.as_str())
.set("x-gladia-key", key)
.set("Content-Type", "application/json")
.send_json(body_json)
{
Ok(r) => r,
Err(e) => {
.map_err(|e| {
tracing::error!(error = ?e, "gladia_init_request_failed");
return None;
}
};
})
.ok()?;

let init: InitResponse = match resp.into_json() {
Ok(r) => r,
Err(e) => {
let init: InitResponse = resp
.into_json()
.map_err(|e| {
tracing::error!(error = ?e, "gladia_init_parse_failed");
return None;
}
};
})
.ok()?;

tracing::debug!(session_id = %init.id, url = %init.url, channels = channels, "gladia_session_initialized");

if let Ok(mut map) = session_channels().lock() {
map.insert(init.id.clone(), channels);
}
SessionChannels::insert(init.id.clone(), channels);

url::Url::parse(&init.url).ok()
}
Expand Down Expand Up @@ -178,14 +173,10 @@ impl RealtimeSttAdapter for GladiaAdapter {
vec![]
}
GladiaMessage::EndSession { id } => {
let channels = session_channels()
.lock()
.ok()
.and_then(|mut map| map.remove(&id))
.unwrap_or_else(|| {
tracing::warn!(session_id = %id, "gladia_session_channels_not_found");
1
});
let channels = SessionChannels::remove(&id).unwrap_or_else(|| {
tracing::warn!(session_id = %id, "gladia_session_channels_not_found");
1
});
tracing::debug!(session_id = %id, channels = channels, "gladia_session_ended");
vec![StreamResponse::TerminalResponse {
request_id: id,
Expand Down Expand Up @@ -371,14 +362,7 @@ impl GladiaAdapter {
};

let channel_idx = utterance.channel.unwrap_or(0);
let total_channels = session_channels()
.lock()
.ok()
.and_then(|map| map.get(&session_id).copied())
.unwrap_or_else(|| {
// Fallback: infer from channel_idx if session not found
(channel_idx + 1).max(1) as u8
});
let total_channels = SessionChannels::get_or_infer(&session_id, channel_idx);

vec![StreamResponse::TranscriptResponse {
is_final,
Expand Down
49 changes: 34 additions & 15 deletions owhisper/owhisper-client/src/adapter/gladia/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@ impl GladiaAdapter {

pub(crate) fn build_ws_url_from_base(api_base: &str) -> (url::Url, Vec<(String, String)>) {
if api_base.is_empty() {
return (
format!("wss://{}{}", DEFAULT_API_HOST, WS_PATH)
.parse()
.expect("invalid_default_ws_url"),
Vec::new(),
);
return (Self::default_ws_url(), Vec::new());
}

if let Some(proxy_result) = super::build_proxy_ws_url(api_base) {
Expand All @@ -37,31 +32,55 @@ impl GladiaAdapter {

let parsed: url::Url = api_base.parse().expect("invalid_api_base");
let existing_params = super::extract_query_params(&parsed);
let url = Self::build_url_with_scheme(&parsed, WS_PATH, true);
(url, existing_params)
}

pub(crate) fn build_http_url(api_base: &str) -> url::Url {
if api_base.is_empty() {
return Self::default_http_url();
}

let parsed: url::Url = api_base.parse().expect("invalid_api_base");
Self::build_url_with_scheme(&parsed, WS_PATH, false)
}

fn build_url_with_scheme(parsed: &url::Url, path: &str, use_ws: bool) -> url::Url {
let host = parsed.host_str().unwrap_or(DEFAULT_API_HOST);
let scheme = if super::is_local_host(host) {
"ws"
} else {
"wss"
let is_local = super::is_local_host(host);
let scheme = match (use_ws, is_local) {
(true, true) => "ws",
(true, false) => "wss",
(false, true) => "http",
(false, false) => "https",
};
let host_with_port = match parsed.port() {
Some(port) => format!("{host}:{port}"),
None => host.to_string(),
};
format!("{scheme}://{host_with_port}{path}")
.parse()
.expect("invalid_url")
}

let url: url::Url = format!("{scheme}://{host_with_port}{WS_PATH}")
fn default_ws_url() -> url::Url {
format!("wss://{}{}", DEFAULT_API_HOST, WS_PATH)
.parse()
.expect("invalid_ws_url");
(url, existing_params)
.expect("invalid_default_ws_url")
}

fn default_http_url() -> url::Url {
format!("https://{}{}", DEFAULT_API_HOST, WS_PATH)
.parse()
.expect("invalid_default_http_url")
}

pub(crate) fn batch_api_url(api_base: &str) -> url::Url {
if api_base.is_empty() {
return API_BASE.parse().expect("invalid_default_api_url");
}

let url: url::Url = api_base.parse().expect("invalid_api_base");
url
api_base.parse().expect("invalid_api_base")
}
}

Expand Down