Skip to content

Commit 0c170c6

Browse files
feat: groq and openai both working
1 parent af4f9b2 commit 0c170c6

File tree

7 files changed

+48
-73
lines changed

7 files changed

+48
-73
lines changed

Cargo.lock

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -14,11 +14,11 @@ categories = ["api-bindings", "web-programming", "asynchronous"]
1414

1515

1616
[dependencies]
17-
reqwest = { version = "0.12", features = ["json"] }
18-
serde = { version = "1.0.203", features = ["derive"] }
19-
tokio = { version = "1", features = ["full"] }
20-
serde_json = "1.0"
21-
anyhow = "1.0.86"
17+
reqwest = { version = "0.12.0", features = ["multipart", "json"] }
18+
serde = { version = "1.0.197", features = ["derive"] }
19+
tokio = { version = "1.36.0", features = ["full"] }
20+
serde_json = "1.0.114"
21+
anyhow = "1.0.81"
2222
chrono = { version = "0.4.38", features = ["serde"] }
2323
reqwest-eventsource = "0.6.0"
2424
futures = "0.3.30"
@@ -33,10 +33,11 @@ tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
3333
tracing-appender = "0.2.2"
3434
hound = "3.5.1"
3535
bytes = { version = "1.10.1", features = ["serde"] }
36-
clap = { version = "4.3", features = ["derive"] }
36+
clap = { version = "4.5", features = ["derive"] }
3737
global-hotkey = "0.6"
3838
colored = "3.0"
3939
bytemuck = { version = "1.15.0", features = ["derive"] } # Added for safe byte casting
4040
dialoguer = "0.11.0"
4141
thiserror = "2.0.12"
4242
rubato = "0.16.2"
43+
groq-api-rust = "0.2.51"

scripts/local_run.sh

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,4 +6,5 @@ export RUST_BACKTRACE=1
66

77
cargo build
88

9-
./target/debug/groq-api-rs
9+
# ./target/debug/groq-api-rs -t groq
10+
./target/debug/groq-api-rs -t open-ai

src/app.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,8 +39,8 @@ impl App {
3939
let shutdown_manager = Arc::new(ShutdownManager::new("tts-groq"));
4040

4141
let provider: Box<dyn TranscriptionProvider + Send + Sync> = match config.transcription_provider {
42-
TranscriptionProviderAPI::Groq => Box::new(providers::GroqProvider::new()),
43-
TranscriptionProviderAPI::OpenAI => Box::new(providers::OpenAIProvider::new()),
42+
TranscriptionProviderAPI::Groq => Box::new(providers::GroqProvider::new().await),
43+
TranscriptionProviderAPI::OpenAI => Box::new(providers::OpenAIProvider::new().await),
4444
};
4545

4646
Ok(Self {

src/providers/async_openai_self.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ pub struct OpenAIProvider {
1010
}
1111

1212
impl OpenAIProvider {
13-
pub fn new() -> Self {
13+
pub async fn new() -> Self {
1414
Self {
1515
client: async_openai::Client::new(),
1616
}

src/providers/groq.rs

Lines changed: 14 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -2,29 +2,20 @@
22
use crate::providers::TranscriptionProvider;
33
use anyhow::{Context, Result};
44
use async_trait::async_trait;
5-
use reqwest::{multipart, Client};
6-
use serde::Deserialize;
5+
use groq_api_rust::{AsyncGroqClient, SpeechToTextRequest};
76
use std::env;
87
use std::time::Duration;
9-
use tracing::{debug, error, info, instrument};
10-
11-
#[derive(Debug, Deserialize)]
12-
struct GroqTranscriptionResponse {
13-
text: String,
14-
}
8+
use tracing::{debug, info, instrument};
159

1610
pub struct GroqProvider {
17-
client: Client,
18-
api_key: String,
11+
client: AsyncGroqClient,
1912
}
2013

2114
impl GroqProvider {
22-
pub fn new() -> Self {
15+
pub async fn new() -> Self {
2316
let api_key = env::var("GROQ_API_KEY").expect("GROQ_API_KEY environment variable not set");
24-
2517
Self {
26-
client: Client::new(),
27-
api_key,
18+
client: AsyncGroqClient::new(api_key, None).await,
2819
}
2920
}
3021
}
@@ -44,48 +35,17 @@ impl TranscriptionProvider for GroqProvider {
4435
info!("Starting Groq transcription request");
4536
debug!("Audio data size: {} bytes", audio_data.len());
4637

47-
// Create multipart form with audio data
48-
let form = multipart::Form::new()
49-
.part(
50-
"file",
51-
multipart::Part::bytes(audio_data.to_vec())
52-
.file_name("audio.wav")
53-
.mime_str("audio/wav")?,
54-
)
55-
.text("model", "whisper-large-v3");
38+
let request = SpeechToTextRequest::new(audio_data.to_vec())
39+
.temperature(0.7)
40+
.language("en")
41+
.model("whisper-large-v3");
5642

57-
// Send request to Groq API
58-
let response = self
59-
.client
60-
.post("https://api.groq.com/v1/audio/transcriptions")
61-
.header("Authorization", format!("Bearer {}", self.api_key))
62-
.multipart(form)
63-
.send()
64-
.await
65-
.context("Failed to send request to Groq API")?;
43+
let response = self.client.speech_to_text(request).await
44+
.context("Failed to get response from Groq")?;
6645

67-
// Parse response
68-
if response.status().is_success() {
69-
let transcription: GroqTranscriptionResponse = response
70-
.json()
71-
.await
72-
.context("Failed to parse Groq API response")?;
46+
info!("Groq transcription successful");
47+
debug!("Received transcription: {}", response.text);
7348

74-
info!("Groq transcription successful");
75-
debug!(
76-
"Received transcription with {} characters",
77-
transcription.text.len()
78-
);
79-
80-
Ok(transcription.text)
81-
} else {
82-
let error_text = response
83-
.text()
84-
.await
85-
.unwrap_or_else(|_| "Unable to get error message".to_string());
86-
87-
error!("Groq transcription failed: {}", error_text);
88-
Err(anyhow::anyhow!("Groq API error: {}", error_text))
89-
}
49+
Ok(response.text)
9050
}
9151
}

src/providers/mod.rs

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -53,28 +53,26 @@ impl TranscriptionProvider for MockProvider {
5353
}
5454
}
5555

56-
pub fn create_provider(provider_name: &str) -> Box<dyn TranscriptionProvider + Send + Sync> {
57-
info!("Creating transcription provider: {}", provider_name);
56+
/// Create a new transcription provider based on the provider name
57+
pub async fn create_provider(provider_name: &str) -> Box<dyn TranscriptionProvider + Send + Sync> {
58+
debug!("Creating provider: {}", provider_name);
5859

5960
match provider_name.to_lowercase().as_str() {
6061
"openai" => {
6162
debug!("Initializing OpenAI provider");
62-
Box::new(OpenAIProvider::new())
63+
Box::new(OpenAIProvider::new().await)
6364
}
6465
"groq" => {
6566
debug!("Initializing Groq provider");
66-
Box::new(self::groq::GroqProvider::new())
67+
Box::new(GroqProvider::new().await)
6768
}
6869
"mock" => {
6970
debug!("Initializing Mock provider");
7071
Box::new(MockProvider::new())
7172
}
7273
_ => {
73-
warn!(
74-
"Unknown provider: {}, falling back to OpenAI",
75-
provider_name
76-
);
77-
Box::new(OpenAIProvider::new())
74+
warn!("Unknown provider '{}', using mock provider", provider_name);
75+
Box::new(MockProvider::new())
7876
}
7977
}
8078
}

0 commit comments

Comments
 (0)