diff --git a/sdk/runanywhere-commons/src/backends/whispercpp/whispercpp_backend.cpp b/sdk/runanywhere-commons/src/backends/whispercpp/whispercpp_backend.cpp index e5ab18fe9..f51b3f803 100644 --- a/sdk/runanywhere-commons/src/backends/whispercpp/whispercpp_backend.cpp +++ b/sdk/runanywhere-commons/src/backends/whispercpp/whispercpp_backend.cpp @@ -274,13 +274,21 @@ STTResult WhisperCppSTT::transcribe_internal(const std::vector& audio, const int n_segments = whisper_full_n_segments(ctx_); std::string full_text; + full_text.reserve(n_segments * 64); + + result.segments.reserve(n_segments); + + if (word_timestamps) { + result.word_timings.reserve(n_segments * 15); + } for (int i = 0; i < n_segments; ++i) { const char* text = whisper_full_get_segment_text(ctx_, i); if (text) { full_text += text; - AudioSegment segment; + result.segments.emplace_back(); + AudioSegment& segment = result.segments.back(); segment.text = text; segment.start_time_ms = whisper_full_get_segment_t0(ctx_, i) * 10.0; segment.end_time_ms = whisper_full_get_segment_t1(ctx_, i) * 10.0; @@ -288,8 +296,6 @@ STTResult WhisperCppSTT::transcribe_internal(const std::vector& audio, float no_speech_prob = whisper_full_get_segment_no_speech_prob(ctx_, i); segment.confidence = 1.0f - no_speech_prob; - result.segments.push_back(segment); - if (word_timestamps) { const int n_tokens = whisper_full_n_tokens(ctx_, i); for (int j = 0; j < n_tokens; ++j) { @@ -297,12 +303,12 @@ STTResult WhisperCppSTT::transcribe_internal(const std::vector& audio, const char* token_text = whisper_full_get_token_text(ctx_, i, j); if (token_text && token_text[0] != '\0' && token_text[0] != '<') { - WordTiming word; + result.word_timings.emplace_back(); + WordTiming& word = result.word_timings.back(); word.word = token_text; word.start_time_ms = token_data.t0 * 10.0; word.end_time_ms = token_data.t1 * 10.0; word.confidence = token_data.p; - result.word_timings.push_back(word); } } } @@ -547,22 +553,64 @@ std::vector WhisperCppSTT::get_supported_languages() const { std::vector WhisperCppSTT::resample_to_16khz(const std::vector& samples, int source_rate) { - if (source_rate == WHISPER_SAMPLE_RATE) { + if (source_rate == WHISPER_SAMPLE_RATE || samples.empty()) { return samples; } - const double ratio = static_cast(WHISPER_SAMPLE_RATE) / source_rate; - const size_t output_size = static_cast(samples.size() * ratio); + const double step = static_cast(source_rate) / WHISPER_SAMPLE_RATE; + + size_t output_size = static_cast(samples.size() / step); + if (output_size == 0) { + output_size = 1; + } + + std::vector output; + + if (source_rate % WHISPER_SAMPLE_RATE == 0) { + const int stride = source_rate / WHISPER_SAMPLE_RATE; + const size_t out_len = std::max(1, samples.size() / stride); + + output.resize(out_len); + for (size_t i = 0; i < out_len; ++i) { + output[i] = samples[i * stride]; + } + return output; + } + + output.resize(output_size); + + const float* __restrict src_ptr = samples.data(); + const size_t src_size = samples.size(); + + const size_t safe_output_limit = (output_size > 0) ? output_size - 1 : 0; + + double pos = 0.0; + size_t i = 0; + + for (; i < safe_output_limit; ++i) { + size_t idx0 = static_cast(pos); + if (idx0 >= src_size - 1) break; + + double frac = pos - idx0; + float val0 = src_ptr[idx0]; + float val1 = src_ptr[idx0 + 1]; + + output[i] = val0 + static_cast(frac) * (val1 - val0); + pos += step; + } + + for (; i < output_size; ++i) { + size_t idx0 = static_cast(pos); + if (idx0 >= src_size) idx0 = src_size - 1; - std::vector output(output_size); + size_t idx1 = (idx0 + 1 < src_size) ? idx0 + 1 : src_size - 1; - for (size_t i = 0; i < output_size; ++i) { - const double src_idx = i / ratio; - const size_t idx0 = static_cast(src_idx); - const size_t idx1 = std::min(idx0 + 1, samples.size() - 1); - const double frac = src_idx - idx0; + double frac = pos - static_cast(idx0); + float val0 = src_ptr[idx0]; + float val1 = src_ptr[idx1]; - output[i] = static_cast(samples[idx0] * (1.0 - frac) + samples[idx1] * frac); + output[i] = val0 + static_cast(frac) * (val1 - val0); + pos += step; } LOGI("Resampled audio from %d Hz to %d Hz (%zu -> %zu samples)", source_rate,