Adding Realtime ASR Client #120

Open
wants to merge 6 commits into
base: main
Changes from 4 commits
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -8,7 +8,7 @@ jobs:
     steps:
       - run:
           name: "Install build dependencies"
-          command: "sudo apt-get --allow-releaseinfo-change update && sudo apt-get install -y wget libasound2-dev libopus-dev libopusfile-dev"
+          command: "sudo apt-get --allow-releaseinfo-change update && sudo apt-get install -y wget libasound2-dev libopus-dev libopusfile-dev libboost-all-dev"
       - run:
           name: "Install bazel"
           command: "wget https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-linux-amd64 && sudo mv bazelisk-linux-amd64 /usr/local/bin/bazelisk && sudo chmod +x /usr/local/bin/bazelisk"
4 changes: 3 additions & 1 deletion Dockerfile
@@ -10,7 +10,8 @@ RUN apt-get update && apt-get install -y \
     libasound2t64 \
     libogg0 \
     openssl \
-    ca-certificates
+    ca-certificates \
+    libboost-all-dev

 FROM base AS builddep
 ARG BAZEL_VERSION
@@ -67,4 +68,5 @@ COPY --from=builder /opt/riva/clients/nlp/riva_nlp_punct /usr/local/bin/
COPY --from=builder /opt/riva/clients/nmt/riva_nmt_t2t_client /usr/local/bin/
COPY --from=builder /opt/riva/clients/nmt/riva_nmt_streaming_s2t_client /usr/local/bin/
COPY --from=builder /opt/riva/clients/nmt/riva_nmt_streaming_s2s_client /usr/local/bin/
COPY --from=builder /opt/riva/clients/realtime/riva_realtime_asr_client /usr/local/bin/
COPY examples /work/examples
6 changes: 6 additions & 0 deletions README.md
@@ -8,6 +8,7 @@ NVIDIA Riva is a GPU-accelerated SDK for building Speech AI applications that ar
- **Automatic Speech Recognition (ASR)**
  - `riva_streaming_asr_client`
  - `riva_asr_client`
  - `riva_realtime_asr_client`
- **Speech Synthesis (TTS)**
  - `riva_tts_client`
  - `riva_tts_perf_client`
@@ -73,6 +74,7 @@ You can find the built binaries in `bazel-bin/riva/clients`
Riva comes with 3 ASR clients:
1. `riva_asr_client` for offline usage. With this client, the server waits until it receives the full audio file before transcribing it and sending the transcript back to the client.
2. `riva_streaming_asr_client` for online usage. With this client, the server starts transcribing after it receives a sufficient amount of audio data, "streaming" intermediate transcripts back to the client as it goes. By default, it transcribes after every `100ms` of audio; this can be changed using the `--chunk_duration_ms` command line flag.
3. `riva_realtime_asr_client` for realtime (WebSocket) usage. This client establishes a persistent WebSocket connection to the server, allowing bidirectional real-time communication. The server starts transcribing after it receives a sufficient amount of audio data and continuously streams intermediate transcripts back to the client as it processes the audio. By default, it transcribes after every `100ms` of audio; this can be changed using the `--chunk_duration_ms` command line flag.

To use the clients, pass in a folder containing audio files or an individual audio file name with the `--audio_file` flag:
```
[...]
```
@@ -82,6 +84,10 @@ or
```
$ riva_asr_client --audio_file audio_folder
```
or
```
$ riva_realtime_asr_client --audio_file individual_audio_file.wav
```

Note that only single-channel audio files in the `.wav` format are currently supported.
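
As a rough illustration of the single-channel constraint, the sketch below reads the channel count, sample rate, and bit depth from a canonical 44-byte PCM WAV header held in memory. The names `WavInfo`, `ParseWavHeader`, and `IsSingleChannel` are hypothetical and are not part of the Riva clients. The sketch also assumes the `fmt ` chunk immediately follows the RIFF header, which is common but not guaranteed for every WAV file.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <optional>
#include <vector>

// Fields of interest from a canonical PCM WAV header (hypothetical struct).
struct WavInfo {
  std::uint16_t channels;
  std::uint32_t sample_rate_hz;
  std::uint16_t bits_per_sample;
};

// Reads a little-endian unsigned integer of n bytes (n <= 4) starting at p.
static std::uint32_t ReadLE(const unsigned char* p, std::size_t n) {
  assert(n <= 4);
  std::uint32_t value = 0;
  for (std::size_t i = 0; i < n; ++i) {
    value |= static_cast<std::uint32_t>(p[i]) << (8 * i);
  }
  return value;
}

// Returns the header fields, or std::nullopt when the buffer does not start
// with the canonical "RIFF" .... "WAVE" "fmt " layout.
std::optional<WavInfo> ParseWavHeader(const std::vector<unsigned char>& buf) {
  if (buf.size() < 44) return std::nullopt;
  if (std::memcmp(buf.data(), "RIFF", 4) != 0) return std::nullopt;
  if (std::memcmp(buf.data() + 8, "WAVE", 4) != 0) return std::nullopt;
  if (std::memcmp(buf.data() + 12, "fmt ", 4) != 0) return std::nullopt;
  return WavInfo{
      static_cast<std::uint16_t>(ReadLE(buf.data() + 22, 2)),   // channel count
      ReadLE(buf.data() + 24, 4),                               // sample rate
      static_cast<std::uint16_t>(ReadLE(buf.data() + 34, 2))};  // bits per sample
}

// True when the header parses and describes mono audio.
bool IsSingleChannel(const std::vector<unsigned char>& buf) {
  const std::optional<WavInfo> info = ParseWavHeader(buf);
  return info.has_value() && info->channels == 1;
}
```

A caller could read the first 44 bytes of a file into a `std::vector<unsigned char>` and reject it early when `IsSingleChannel` returns `false`, before passing the path to one of the clients.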

8 changes: 8 additions & 0 deletions WORKSPACE
@@ -102,3 +102,11 @@ http_archive(
    strip_prefix = "platforms-1.0.0",
    sha256 = "852b71bfa15712cec124e4a57179b6bc95d59fdf5052945f5d550e072501a769",
)

http_archive(
    name = "websocketpp",
    urls = ["https://github.com/zaphoyd/websocketpp/archive/refs/tags/0.8.2.tar.gz"],
    sha256 = "6ce889d85ecdc2d8fa07408d6787e7352510750daa66b5ad44aacb47bea76755",
    strip_prefix = "websocketpp-0.8.2",
    build_file = "//third_party:BUILD.websocketpp"
)
59 changes: 59 additions & 0 deletions riva/clients/realtime/BUILD
@@ -0,0 +1,59 @@
"""
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
"""

package(
    default_visibility = ["//visibility:public"],
)

cc_library(
    name = "realtime_audio_client_lib",
    srcs = [
        "audio_chunks.cpp",
        "base_client.cpp",
        "recognition_client.cpp",
    ],
    hdrs = [
        "audio_chunks.h",
        "base_client.h",
        "recognition_client.h",
    ],
    deps = [
        "//riva/utils/wav:reader",
        "//riva/utils/stats_builder:stats_builder_lib",
        "@websocketpp//:websocketpp",
        "@rapidjson//:rapidjson",
        "@glog//:glog",
        "@com_github_gflags_gflags//:gflags",
    ],
)

cc_binary(
    name = "riva_realtime_asr_client",
    srcs = ["riva_realtime_asr_client.cc"],
    includes = ["-Irealtime"],
    deps = [
        ":realtime_audio_client_lib",
        "@websocketpp//:websocketpp",
        "@rapidjson//:rapidjson",
        "//riva/utils/stats_builder:stats_builder_lib",
        "//riva/utils/wav:reader",
    ] + select({
        "@platforms//cpu:aarch64": [
            "@alsa_aarch64//:libasound"
        ],
        "//conditions:default": [
            "@alsa//:libasound"
        ],
    }),
    linkopts = [
        "-lssl",
        "-lcrypto",
        "-lboost_system",
    ]
)
160 changes: 160 additions & 0 deletions riva/clients/realtime/audio_chunks.cpp
@@ -0,0 +1,160 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: MIT
*/

#include "audio_chunks.h"
#include "riva/utils/wav/wav_reader.h"
#include "riva/utils/wav/wav_data.h"
#include <filesystem>
#include <fstream>
#include <iostream>
#include <sstream>
#include <iomanip>
#include <numeric>
#include <algorithm>
#include <future>
#include <cstdint>
#include <vector>

nvidia::riva::realtime::AudioChunks::AudioChunks(const std::string& filepath, const int& chunk_size_ms)
    : filepath_(filepath), chunk_size_ms_(chunk_size_ms) {
}

void nvidia::riva::realtime::AudioChunks::CalculateChunkSizeBytes() {
  chunk_size_bytes_ = (GetSampleRateHz() * GetChunkSizeMs() / 1000) * sizeof(int16_t);
  std::cout << "[AudioChunks] Calculated chunk size: " << chunk_size_bytes_ << " bytes" << std::endl;
}

void nvidia::riva::realtime::AudioChunks::SplitIntoChunks() {
  const std::vector<char>& raw_data = wav_data_->data;
  size_t total_size = raw_data.size();

  std::cout << "[AudioChunks] Splitting WAV file into chunks of " << chunk_size_bytes_ << " bytes" << std::endl;
  // The chunk_size_bytes_ > 0 condition keeps a zero chunk size from looping forever.
  chunk_base64s_.clear();
  for (size_t i = 0; chunk_size_bytes_ > 0 && i < total_size; i += chunk_size_bytes_) {
    size_t current_chunk_size = std::min(chunk_size_bytes_, total_size - i);
    std::vector<char> chunk(raw_data.begin() + i, raw_data.begin() + i + current_chunk_size);
    std::string chunk_base64 = EncodeBase64(chunk);
    chunk_base64s_.push_back(chunk_base64);
  }
}

std::string nvidia::riva::realtime::AudioChunks::EncodeBase64(const std::vector<char>& data) {
  const std::string base64_chars =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
      "abcdefghijklmnopqrstuvwxyz"
      "0123456789+/";

  std::string result;
  unsigned int val = 0;  // unsigned, so the repeated left shift below wraps rather than overflowing
  int valb = -6;
  for (unsigned char c : data) {
    val = (val << 8) + c;
    valb += 8;
    while (valb >= 0) {
      result.push_back(base64_chars[(val >> valb) & 0x3F]);
      valb -= 6;
    }
  }

  if (valb > -6) {
    result.push_back(base64_chars[((val << 8) >> (valb + 8)) & 0x3F]);
  }

  while (result.size() % 4) {
    result.push_back('=');
  }

  return result;
}

bool nvidia::riva::realtime::AudioChunks::Init() {
  if (initialized_) {
    std::cout << "[AudioChunks] Chunks already initialized" << std::endl;
    return true;
  }

  std::cout << "[AudioChunks] Initializing chunks for file: " << filepath_ << std::endl;
  fs::path path(filepath_);
  std::string extension = path.extension().string();

  // The file must exist
  if (!fs::exists(filepath_)) {
    std::cerr << "[AudioChunks] Error: File does not exist, " << filepath_ << std::endl;
    return false;
  }

  // The file must have a .wav extension (the comparison is case-sensitive)
  if (extension != ".wav") {
    std::cerr << "[AudioChunks] Error: File is not a WAV file, " << filepath_ << std::endl;
    return false;
  }

  // Load WAV file using the existing WAV utilities
  std::vector<std::shared_ptr<WaveData>> all_wav;
  LoadWavData(all_wav, filepath_);

  if (all_wav.empty()) {
    std::cerr << "[AudioChunks] Error: Failed to load WAV file, " << filepath_ << std::endl;
    return false;
  }

  wav_data_ = all_wav[0];  // Use the first WAV file

  CalculateChunkSizeBytes();
  SplitIntoChunks();

  initialized_ = true;

  return initialized_;
}

// Getter implementations
std::string nvidia::riva::realtime::AudioChunks::GetFilepath() const {
  return filepath_;
}

size_t nvidia::riva::realtime::AudioChunks::GetChunkSizeMs() const {
  return chunk_size_ms_;
}

size_t nvidia::riva::realtime::AudioChunks::GetChunkSizeBytes() const {
  return chunk_size_bytes_;
}

bool nvidia::riva::realtime::AudioChunks::IsInitialized() const {
  return initialized_;
}

// WAV file properties
int nvidia::riva::realtime::AudioChunks::GetSampleRateHz() const {
  return wav_data_->sample_rate;
}

int nvidia::riva::realtime::AudioChunks::GetNumChannels() const {
  return wav_data_->channels;
}

int nvidia::riva::realtime::AudioChunks::GetBitDepth() const {
  // The bit depth cannot be derived from the data size, sample rate, and
  // channel count alone, because the data size also scales with the clip
  // duration. The rest of this class (chunk sizing, duration, and sample
  // count) assumes 16-bit PCM, so report that.
  return 16;
}

double nvidia::riva::realtime::AudioChunks::GetDurationSeconds() const {
  if (wav_data_->sample_rate > 0 && wav_data_->channels > 0) {
    return static_cast<double>(wav_data_->data.size()) / (wav_data_->sample_rate * wav_data_->channels * 2);  // Assuming 16-bit
  }
  return 0.0;
}

int nvidia::riva::realtime::AudioChunks::GetNumSamples() const {
  if (wav_data_->channels > 0) {
    return wav_data_->data.size() / (wav_data_->channels * 2);  // Assuming 16-bit
  }
  return 0;
}
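
The base64 step in `EncodeBase64` can be checked in isolation against the test vectors from RFC 4648, section 10. The sketch below restates the same bit-accumulator approach as a free function so it compiles without the Riva headers. `Base64Encode` and `CheckRfc4648Vectors` are hypothetical names, not part of this change. The accumulator is declared unsigned here so that the repeated left shift wraps instead of overflowing a signed `int`.

```cpp
#include <cassert>
#include <string>

// Hypothetical free-function version of the encoder, for testing in isolation.
std::string Base64Encode(const std::string& data) {
  static const char kAlphabet[] =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
      "abcdefghijklmnopqrstuvwxyz"
      "0123456789+/";

  std::string out;
  unsigned int acc = 0;  // bit accumulator; only its low bits are ever read
  int bits = -6;         // pending bit count, offset by -6
  for (unsigned char c : data) {
    acc = (acc << 8) + c;
    bits += 8;
    while (bits >= 0) {
      out.push_back(kAlphabet[(acc >> bits) & 0x3F]);
      bits -= 6;
    }
  }
  // Emit the final partial group, left-aligned within its 6-bit symbol.
  if (bits > -6) {
    out.push_back(kAlphabet[((acc << 8) >> (bits + 8)) & 0x3F]);
  }
  // Pad to a multiple of four output characters.
  while (out.size() % 4 != 0) {
    out.push_back('=');
  }
  return out;
}

// Asserts that the encoder reproduces the RFC 4648 section 10 test vectors.
void CheckRfc4648Vectors() {
  assert(Base64Encode("") == "");
  assert(Base64Encode("f") == "Zg==");
  assert(Base64Encode("fo") == "Zm8=");
  assert(Base64Encode("foo") == "Zm9v");
  assert(Base64Encode("foob") == "Zm9vYg==");
  assert(Base64Encode("fooba") == "Zm9vYmE=");
  assert(Base64Encode("foobar") == "Zm9vYmFy");
}
```

Calling `CheckRfc4648Vectors()` from a small test binary exercises every padding case (zero, one, and two `=` characters), which is where hand-written encoders most often go wrong.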
84 changes: 84 additions & 0 deletions riva/clients/realtime/audio_chunks.h
@@ -0,0 +1,84 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: MIT
*/

#ifndef AUDIO_CHUNKS_H
#define AUDIO_CHUNKS_H

#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <filesystem>
#include "riva/utils/wav/wav_reader.h"
#include "riva/utils/wav/wav_data.h"

namespace fs = std::filesystem;

namespace nvidia::riva::realtime {
class AudioChunks {
 private:
  bool initialized_ = false;
  std::string filepath_;
  size_t chunk_size_ms_;
  size_t chunk_size_bytes_;
  std::shared_ptr<WaveData> wav_data_;
  std::vector<std::string> chunk_base64s_;

  void CalculateChunkSizeBytes();
  void SplitIntoChunks();
  std::string EncodeBase64(const std::vector<char>& data);

 public:
  AudioChunks(const std::string& filepath, const int& chunk_size_ms);
  ~AudioChunks() = default;

  bool Init();

  // Getters
  std::string GetFilepath() const;
  size_t GetChunkSizeMs() const;
  size_t GetChunkSizeBytes() const;
  bool IsInitialized() const;

  // WAV file properties
  int GetSampleRateHz() const;
  int GetNumChannels() const;
  int GetBitDepth() const;
  double GetDurationSeconds() const;
  int GetNumSamples() const;
  const std::vector<std::string>& GetChunkBase64s() const { return chunk_base64s_; }

  // Iterator support
  using iterator = std::vector<std::string>::iterator;
  using const_iterator = std::vector<std::string>::const_iterator;
  using reverse_iterator = std::vector<std::string>::reverse_iterator;
  using const_reverse_iterator = std::vector<std::string>::const_reverse_iterator;

  // Iterator methods
  iterator begin() { return chunk_base64s_.begin(); }
  const_iterator begin() const { return chunk_base64s_.begin(); }
  iterator end() { return chunk_base64s_.end(); }
  const_iterator end() const { return chunk_base64s_.end(); }

  // Reverse iterator methods
  reverse_iterator rbegin() { return chunk_base64s_.rbegin(); }
  const_reverse_iterator rbegin() const { return chunk_base64s_.rbegin(); }
  reverse_iterator rend() { return chunk_base64s_.rend(); }
  const_reverse_iterator rend() const { return chunk_base64s_.rend(); }

  // Const iterator methods
  const_iterator cbegin() const { return chunk_base64s_.cbegin(); }
  const_iterator cend() const { return chunk_base64s_.cend(); }
  const_reverse_iterator crbegin() const { return chunk_base64s_.crbegin(); }
  const_reverse_iterator crend() const { return chunk_base64s_.crend(); }

  // Size methods
  size_t size() const { return chunk_base64s_.size(); }
  bool empty() const { return chunk_base64s_.empty(); }
};

}  // namespace nvidia::riva::realtime

#endif // AUDIO_CHUNKS_H