-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtemperature_fallback.rs
More file actions
257 lines (210 loc) · 9.07 KB
/
temperature_fallback.rs
File metadata and controls
257 lines (210 loc) · 9.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
//! Example demonstrating temperature fallback for improved transcription quality
use std::path::{Path, PathBuf};
use whisper_cpp_plus::{WhisperContext, TranscriptionParams, FullParams, SamplingStrategy};
use whisper_cpp_plus::enhanced::fallback::{
EnhancedTranscriptionParams, EnhancedTranscriptionParamsBuilder,
QualityThresholds, EnhancedWhisperState
};
fn main() -> Result<(), Box<dyn std::error::Error>> {
let model_path = find_model("ggml-tiny.en.bin")
.ok_or("Model not found. Run: cargo xtask test-setup")?;
println!("Loading model from {:?}...", model_path);
let ctx = WhisperContext::new(&model_path)?;
// Load audio (you would load real audio here)
let (clear_audio, noisy_audio) = load_audio_examples()?;
// Example 1: Standard transcription vs Enhanced with fallback
println!("\n=== Example 1: Clear Audio ===");
compare_transcription_methods(&ctx, &clear_audio)?;
// Example 2: Noisy/difficult audio
println!("\n=== Example 2: Noisy/Difficult Audio ===");
compare_transcription_methods(&ctx, &noisy_audio)?;
// Example 3: Custom quality thresholds
println!("\n=== Example 3: Custom Quality Thresholds ===");
demonstrate_custom_thresholds(&ctx, &noisy_audio)?;
// Example 4: Direct enhanced state usage
println!("\n=== Example 4: Direct Enhanced State Control ===");
demonstrate_direct_enhanced_state(&ctx, &noisy_audio)?;
Ok(())
}
/// Transcribes `audio` twice and prints both results side by side.
///
/// The first pass uses `WhisperContext::transcribe`. The second uses
/// `transcribe_with_params_enhanced` with English-language parameters. The text
/// and the wall-clock time of each pass are printed. Timing covers only the
/// transcription call itself, not parameter construction. A note is printed
/// when the two texts differ.
fn compare_transcription_methods(
    ctx: &WhisperContext,
    audio: &[f32]
) -> Result<(), Box<dyn std::error::Error>> {
    use std::time::Instant;

    println!("1. Standard transcription:");
    let timer = Instant::now();
    let baseline = ctx.transcribe(audio)?;
    let baseline_elapsed = timer.elapsed();
    println!(" Text: {}", baseline);
    println!(" Time: {:?}", baseline_elapsed);

    println!("\n2. Enhanced transcription with temperature fallback:");
    let params = TranscriptionParams::builder().language("en").build();
    let timer = Instant::now();
    let enhanced = ctx.transcribe_with_params_enhanced(audio, params)?;
    let enhanced_elapsed = timer.elapsed();
    println!(" Text: {}", enhanced.text);
    println!(" Time: {:?}", enhanced_elapsed);

    if enhanced.text != baseline {
        println!(" Note: Enhanced version produced different (likely better) result!");
    }
    Ok(())
}
/// Transcribes `audio` with hand-tuned fallback parameters and prints the outcome.
///
/// The parameters use:
/// - greedy sampling (`best_of: 1`) in English;
/// - a six-step temperature ladder from 0.0 to 1.0;
/// - a compression-ratio threshold of 2.0;
/// - a log-probability threshold of -0.5.
///
/// After printing the configuration, the function creates a fresh state, runs
/// `transcribe_with_fallback`, and lists every resulting segment with its
/// start/end time.
fn demonstrate_custom_thresholds(
    ctx: &WhisperContext,
    audio: &[f32]
) -> Result<(), Box<dyn std::error::Error>> {
    println!("Creating enhanced parameters with custom quality thresholds...");

    let base = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }).language("en");
    let enhanced_params = EnhancedTranscriptionParamsBuilder::new()
        .base_params(base)
        .temperatures(vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
        .compression_ratio_threshold(Some(2.0)) // Stricter than default 2.4
        .log_prob_threshold(Some(-0.5)) // Stricter than default -1.0
        .build();

    let limits = &enhanced_params.thresholds;
    println!("Quality thresholds:");
    println!(" - Max compression ratio: {:?}", limits.compression_ratio_threshold);
    println!(" - Min log probability: {:?}", limits.log_prob_threshold);
    println!(" - Temperature sequence: {:?}", enhanced_params.temperatures);

    let mut state = ctx.create_state()?;
    let mut enhanced_state = EnhancedWhisperState::new(&mut state);
    let result = enhanced_state.transcribe_with_fallback(enhanced_params, audio)?;

    println!("\nTranscription result:");
    println!(" Text: {}", result.text);
    println!(" Segments: {}", result.segments.len());
    // Segments are numbered from 1 for display.
    for (ordinal, segment) in (1usize..).zip(result.segments.iter()) {
        println!(
            " Segment {}: [{:.2}s - {:.2}s] {}",
            ordinal,
            segment.start_seconds(),
            segment.end_seconds(),
            segment.text
        );
    }
    Ok(())
}
/// Runs two fallback transcriptions of `audio` through one reused state.
///
/// The first run uses relaxed quality thresholds and the temperature ladder
/// [0.0, 0.5, 1.0]. The second uses strict thresholds and a finer ladder from
/// 0.0 to 1.0 in steps of 0.2. Both runs share English-language default base
/// parameters and `prompt_reset_on_temperature` of 0.5. Each run's text is
/// printed.
fn demonstrate_direct_enhanced_state(
    ctx: &WhisperContext,
    audio: &[f32]
) -> Result<(), Box<dyn std::error::Error>> {
    println!("Using enhanced state directly for fine control...");
    // One whisper state is created up front and reused for both runs.
    let mut state = ctx.create_state()?;

    // Looser limits than the strict profile below on every field.
    let relaxed_thresholds = QualityThresholds {
        compression_ratio_threshold: Some(3.0),
        log_prob_threshold: Some(-2.0),
        no_speech_threshold: Some(0.8),
    };
    // Tighter limits than the relaxed profile on every field.
    let strict_thresholds = QualityThresholds {
        compression_ratio_threshold: Some(1.5),
        log_prob_threshold: Some(-0.3),
        no_speech_threshold: Some(0.4),
    };

    // Only the temperature ladder and thresholds differ between the two runs.
    let fallback_params = |temperatures, thresholds| EnhancedTranscriptionParams {
        base: FullParams::default().language("en"),
        temperatures,
        thresholds,
        prompt_reset_on_temperature: 0.5,
    };

    println!("\n1. With relaxed thresholds:");
    let relaxed_params = fallback_params(vec![0.0, 0.5, 1.0], relaxed_thresholds);
    let mut enhanced_state = EnhancedWhisperState::new(&mut state);
    let relaxed_result = enhanced_state.transcribe_with_fallback(relaxed_params, audio)?;
    println!(" Result: {}", relaxed_result.text);

    println!("\n2. With strict thresholds:");
    let strict_params = fallback_params(vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0], strict_thresholds);
    let strict_result = enhanced_state.transcribe_with_fallback(strict_params, audio)?;
    println!(" Result: {}", strict_result.text);
    println!(" Note: Stricter thresholds may have triggered more temperature fallbacks");
    Ok(())
}
/// Loads a clear speech clip and a noisy counterpart for the demos.
///
/// The clear clip comes from the first location that exists:
/// 1. `$WHISPER_TEST_AUDIO_DIR/jfk.wav`;
/// 2. a fixed list of repository-relative paths.
///
/// If none exists, an error is printed to stderr and returned. The noisy clip is
/// read from `samples/noisy_speech.wav` when present. Otherwise it is
/// synthesized by adding noise to the clear clip.
fn load_audio_examples() -> Result<(Vec<f32>, Vec<f32>), Box<dyn std::error::Error>> {
    // Repository-relative fallbacks, tried in order.
    let bundled_candidates = [
        "../whisper-cpp-plus-sys/whisper.cpp/samples/jfk.wav",
        "whisper-cpp-plus-sys/whisper.cpp/samples/jfk.wav",
        "samples/clear_speech.wav",
    ];

    let clear_source: Option<String> = std::env::var("WHISPER_TEST_AUDIO_DIR")
        .ok()
        .map(|dir| format!("{}/jfk.wav", dir))
        .filter(|candidate| Path::new(candidate).exists())
        .or_else(|| {
            bundled_candidates
                .iter()
                .find(|candidate| Path::new(candidate).exists())
                .map(|candidate| String::from(*candidate))
        });

    let clear_audio = match clear_source {
        Some(source) => {
            println!("Loading clear audio from: {}", source);
            load_wav_file(&source)?
        }
        None => {
            eprintln!("\nError: No audio files found!");
            eprintln!("Set WHISPER_TEST_AUDIO_DIR env var or provide audio.");
            return Err("No audio files found".into());
        }
    };

    // Prefer a real noisy recording; otherwise derive one from the clear clip.
    let noisy_source = "samples/noisy_speech.wav";
    let noisy_audio = if !Path::new(noisy_source).exists() {
        println!("Creating noisy version from clear audio for demonstration...");
        add_noise_to_audio(&clear_audio)
    } else {
        println!("Loading noisy audio from: {}", noisy_source);
        load_wav_file(noisy_source)?
    };

    Ok((clear_audio, noisy_audio))
}
/// Returns a copy of `audio` with pseudo-random noise added to each sample.
///
/// Each sample's offset lies within ±0.075, and results are clipped back into
/// `[-1.0, 1.0]`. The noise comes from a hasher built by `RandomState::new()`,
/// so the output is not reproducible from run to run.
fn add_noise_to_audio(audio: &[f32]) -> Vec<f32> {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};

    // Peak-to-peak width of the uniform noise band (offsets span ±0.075).
    const NOISE_SPAN: f32 = 0.15;

    let mut hasher = RandomState::new().build_hasher();
    let mut noisy = Vec::with_capacity(audio.len());
    for (index, &sample) in audio.iter().enumerate() {
        // The hasher is never reset, so feeding it the index advances its state.
        // Each `finish()` therefore yields a fresh pseudo-random u64.
        index.hash(&mut hasher);
        let unit = hasher.finish() as f32 / u64::MAX as f32;
        let offset = (unit - 0.5) * NOISE_SPAN;
        noisy.push((sample + offset).max(-1.0).min(1.0));
    }
    noisy
}
fn load_wav_file(path: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
use hound;
let mut reader = hound::WavReader::open(path)?;
let spec = reader.spec();
// Check format
if spec.sample_rate != 16000 {
eprintln!("Warning: Audio sample rate is {}Hz, expected 16000Hz", spec.sample_rate);
}
if spec.channels != 1 {
eprintln!("Warning: Audio has {} channels, using first channel only", spec.channels);
}
let samples: Vec<f32> = reader
.samples::<i16>()
.step_by(spec.channels as usize)
.map(|s| s.unwrap() as f32 / 32768.0)
.collect();
Ok(samples)
}
/// Searches for a model file called `name`; returns the first path that exists.
///
/// Search order:
/// 1. The directories named by `WHISPER_TEST_MODEL_DIR` and then
///    `WHISPER_MODEL_PATH`, when set.
/// 2. A fixed list of repository-relative model directories.
///
/// Returns `None` if the file is found nowhere.
fn find_model(name: &str) -> Option<PathBuf> {
    const MODEL_DIR_VARS: [&str; 2] = ["WHISPER_TEST_MODEL_DIR", "WHISPER_MODEL_PATH"];

    MODEL_DIR_VARS
        .iter()
        .filter_map(|var| std::env::var(var).ok())
        .map(|dir| Path::new(&dir).join(name))
        .find(|candidate| candidate.exists())
        .or_else(|| {
            let fallback_dirs = [
                "tests/models",
                "whisper-cpp-plus/tests/models",
                "../whisper-cpp-plus-sys/whisper.cpp/models",
                "whisper-cpp-plus-sys/whisper.cpp/models",
            ];
            fallback_dirs
                .iter()
                .map(|dir| format!("{}/{}", dir, name))
                .find(|candidate| Path::new(candidate).exists())
                .map(PathBuf::from)
        })
}