diff --git a/.circleci/config.yml b/.circleci/config.yml
index 238bafe..5103715 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -8,7 +8,7 @@ jobs:
     steps:
       - run:
          name: "Install build dependencies"
-         command: "sudo apt-get --allow-releaseinfo-change update && sudo apt-get install -y wget libasound2-dev libopus-dev libopusfile-dev"
+         command: "sudo apt-get --allow-releaseinfo-change update && sudo apt-get install -y wget libasound2-dev libopus-dev libopusfile-dev libboost-all-dev"
      - run:
          name: "Install bazel"
          command: "wget https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-linux-amd64 && sudo mv bazelisk-linux-amd64 /usr/local/bin/bazelisk && sudo chmod +x /usr/local/bin/bazelisk"
diff --git a/Dockerfile b/Dockerfile
index a7c6e0a..73da6b6 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -10,7 +10,8 @@ RUN apt-get update && apt-get install -y \
     libasound2t64 \
     libogg0 \
     openssl \
-    ca-certificates
+    ca-certificates \
+    libboost-all-dev

 FROM base AS builddep
 ARG BAZEL_VERSION
@@ -67,4 +68,5 @@ COPY --from=builder /opt/riva/clients/nlp/riva_nlp_punct /usr/local/bin/
 COPY --from=builder /opt/riva/clients/nmt/riva_nmt_t2t_client /usr/local/bin/
 COPY --from=builder /opt/riva/clients/nmt/riva_nmt_streaming_s2t_client /usr/local/bin/
 COPY --from=builder /opt/riva/clients/nmt/riva_nmt_streaming_s2s_client /usr/local/bin/
+COPY --from=builder /opt/riva/clients/realtime/riva_realtime_asr_client /usr/local/bin/
 COPY examples /work/examples
diff --git a/README.md b/README.md
index 6092fe8..6301d40 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@ NVIDIA Riva is a GPU-accelerated SDK for building Speech AI applications that ar
 - **Automatic Speech Recognition (ASR)**
   - `riva_streaming_asr_client`
   - `riva_asr_client`
+  - `riva_realtime_asr_client`
 - **Speech Synthesis (TTS)**
   - `riva_tts_client`
   - `riva_tts_perf_client`
@@ -73,6 +74,7 @@ You can find the built binaries in `bazel-bin/riva/clients`

-Riva comes with 2 ASR clients:
+Riva comes with 3 ASR clients:
 1. `riva_asr_client` for offline usage. With this client, the server waits until it has received the full audio file before transcribing it and sending the result back to the client.
 2. `riva_streaming_asr_client` for online usage. With this client, the server starts transcribing once it has received a sufficient amount of audio data, "streaming" intermediate transcripts back to the client as it goes. By default, it transcribes after every `100ms` of audio; this can be changed with the `--chunk_duration_ms` command-line flag.
+3. `riva_realtime_asr_client` for realtime (websocket) usage. This client establishes a persistent websocket connection to the server, allowing bidirectional real-time communication. The server starts transcribing after it receives a sufficient amount of audio data and continuously streams intermediate transcripts back to the client as it processes the audio. By default, it transcribes after every `100ms`, which can be changed with the `--chunk_duration_ms` command-line flag.

 To use the clients, simply pass in a folder containing audio files or an individual audio file name with the `audio_file` flag:
 ```
 $ riva_asr_client --audio_file individual_audio_file.wav
 ```
 or
 ```
 $ riva_asr_client --audio_file audio_folder
 ```
+or
+```
+$ riva_realtime_asr_client --audio_file individual_audio_file.wav
+```

 Note that only single-channel audio files in the `.wav` format are currently supported.
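The realtime client accepts the same tuning flags as the streaming client. As a sketch, the `--riva_uri`, `--chunk_duration_ms`, and `--simulate_realtime` flags declared in this PR's gflags definitions can be combined to stream a file at real-time pace against a non-default endpoint (values shown are the defaults from this diff):
```
$ riva_realtime_asr_client --audio_file individual_audio_file.wav \
    --riva_uri "ws://127.0.0.1:9090/v1/realtime?intent=transcription" \
    --chunk_duration_ms 100 \
    --simulate_realtime
```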
diff --git a/WORKSPACE b/WORKSPACE
index ac622ea..9331c0f 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -102,3 +102,11 @@ http_archive(
     strip_prefix = "platforms-1.0.0",
     sha256 = "852b71bfa15712cec124e4a57179b6bc95d59fdf5052945f5d550e072501a769",
 )
+
+http_archive(
+    name = "websocketpp",
+    urls = ["https://github.com/zaphoyd/websocketpp/archive/refs/tags/0.8.2.tar.gz"],
+    sha256 = "6ce889d85ecdc2d8fa07408d6787e7352510750daa66b5ad44aacb47bea76755",
+    strip_prefix = "websocketpp-0.8.2",
+    build_file = "//third_party:BUILD.websocketpp",
+)
\ No newline at end of file
diff --git a/riva/clients/realtime/BUILD b/riva/clients/realtime/BUILD
new file mode 100644
index 0000000..3d7d4fe
--- /dev/null
+++ b/riva/clients/realtime/BUILD
@@ -0,0 +1,59 @@
+"""
+Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+NVIDIA CORPORATION and its licensors retain all intellectual property
+and proprietary rights in and to this software, related documentation
+and any modifications thereto. Any use, reproduction, disclosure or
+distribution of this software and related documentation without an express
+license agreement from NVIDIA CORPORATION is strictly prohibited.
+"""
+
+package(
+    default_visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "realtime_audio_client_lib",
+    srcs = [
+        "audio_chunks.cpp",
+        "base_client.cpp",
+        "realtime_client.cpp",
+    ],
+    hdrs = [
+        "audio_chunks.h",
+        "base_client.h",
+        "realtime_client.h",
+    ],
+    deps = [
+        "//riva/utils/wav:reader",
+        "//riva/utils/stats_builder:stats_builder_lib",
+        "@websocketpp//:websocketpp",
+        "@rapidjson//:rapidjson",
+        "@glog//:glog",
+        "@com_github_gflags_gflags//:gflags",
+    ],
+)
+
+cc_binary(
+    name = "riva_realtime_asr_client",
+    srcs = ["riva_realtime_asr_client.cc"],
+    includes = ["."],
+    deps = [
+        ":realtime_audio_client_lib",
+        "@websocketpp//:websocketpp",
+        "@rapidjson//:rapidjson",
+        "//riva/utils/stats_builder:stats_builder_lib",
+        "//riva/utils/wav:reader",
+    ] + select({
+        "@platforms//cpu:aarch64": [
+            "@alsa_aarch64//:libasound"
+        ],
+        "//conditions:default": [
+            "@alsa//:libasound"
+        ],
+    }),
+    linkopts = [
+        "-lssl",
+        "-lcrypto",
+        "-lboost_system",
+    ]
)
\ No newline at end of file
diff --git a/riva/clients/realtime/audio_chunks.cpp b/riva/clients/realtime/audio_chunks.cpp
new file mode 100644
index 0000000..3a36b7f
--- /dev/null
+++ b/riva/clients/realtime/audio_chunks.cpp
@@ -0,0 +1,511 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "audio_chunks.h"
+
+#include <alsa/asoundlib.h>
+
+#include <algorithm>
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <cstdint>
+#include <filesystem>
+#include <iostream>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "riva/utils/wav/wav_data.h"
+#include "riva/utils/wav/wav_reader.h"
+
+namespace nvidia::riva::realtime {
+
+// ============================================================================
+// Base AudioChunks class implementation
+// ============================================================================
+
+AudioChunks::AudioChunks(const int& chunk_size_ms) : chunk_size_ms_(chunk_size_ms) {}
+
+void
+AudioChunks::CalculateChunkSizeBytes(int sample_rate)
+{
+  chunk_size_bytes_ = (sample_rate * chunk_size_ms_ / 1000) * sizeof(int16_t);
+  std::cout << "[AudioChunks] Calculated chunk size: " << chunk_size_bytes_ << " bytes"
+            << std::endl;
+}
+
+std::string
+AudioChunks::EncodeBase64(const std::vector<char>& data)
+{
+  const std::string base64_chars =
+      "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+      "abcdefghijklmnopqrstuvwxyz"
+      "0123456789+/";
+
+  std::string result;
+  int val = 0, valb = -6;
+
+  for (unsigned char c : data) {
+    val = (val << 8) + c;
+    valb += 8;
+    while (valb >= 0) {
+      result.push_back(base64_chars[(val >> valb) & 0x3F]);
+      valb -= 6;
+    }
+  }
+
+  if (valb > -6) {
+    result.push_back(base64_chars[((val << 8) >> (valb + 8)) & 0x3F]);
+  }
+
+  while (result.size() % 4) {
+    result.push_back('=');
+  }
+
+  return result;
+}
+
+bool
+AudioChunks::Init()
+{
+  if (initialized_) {
+    std::cout << "[AudioChunks] Chunks already initialized" << std::endl;
+    return true;
+  }
+
+  std::cout << "[AudioChunks] Initializing audio chunks..." << std::endl;
+
+  if (!InitializeAudio()) {
+    std::cerr << "[AudioChunks] Error: Failed to initialize audio" << std::endl;
+    return false;
+  }
+
+  ProcessAudioData();
+
+  initialized_ = true;
+  std::cout << "[AudioChunks] Successfully initialized with " << chunk_base64s_.size() << " chunks"
+            << std::endl;
+
+  return initialized_;
+}
+
+// Getter implementations
+size_t
+AudioChunks::GetChunkSizeMs() const
+{
+  return chunk_size_ms_;
+}
+
+size_t
+AudioChunks::GetChunkSizeBytes() const
+{
+  return chunk_size_bytes_;
+}
+
+bool
+AudioChunks::IsInitialized() const
+{
+  return initialized_;
+}
+
+const std::vector<std::string>&
+AudioChunks::GetChunkBase64s() const
+{
+  return chunk_base64s_;
+}
+
+// ============================================================================
+// FileAudioChunks derived class implementation
+// ============================================================================
+
+FileAudioChunks::FileAudioChunks(const std::string& filepath, const int& chunk_size_ms)
+    : AudioChunks(chunk_size_ms), filepath_(filepath)
+{
+}
+
+void
+FileAudioChunks::SplitIntoChunks()
+{
+  const std::vector<char>& raw_data = wav_data_->data;
+  size_t total_size = raw_data.size();
+
+  std::cout << "[FileAudioChunks] Splitting WAV file into chunks of " << chunk_size_bytes_
+            << " bytes" << std::endl;
+
+  chunk_base64s_.clear();
+  for (size_t i = 0; i < total_size; i += chunk_size_bytes_) {
+    size_t current_chunk_size = std::min(chunk_size_bytes_, total_size - i);
+    std::vector<char> chunk(raw_data.begin() + i, raw_data.begin() + i + current_chunk_size);
+    std::string chunk_base64 = EncodeBase64(chunk);
+    chunk_base64s_.push_back(chunk_base64);
+  }
+}
+
+bool
+FileAudioChunks::InitializeAudio()
+{
+  std::cout << "[FileAudioChunks] Initializing file audio for: " << filepath_ << std::endl;
+  fs::path path(filepath_);
+  std::string extension = path.extension().string();
+
+  // File exists
+  if (!fs::exists(filepath_)) {
+    std::cerr << "[FileAudioChunks] Error: File does not exist, " << filepath_ << std::endl;
+    return false;
+  }
+
+  // File is a WAV file
+  if (extension != ".wav") {
+    std::cerr << "[FileAudioChunks] Error: File is not a WAV file, " << filepath_ << std::endl;
+    return false;
+  }
+
+  // Load WAV file using the existing WAV utilities
+  std::vector<std::shared_ptr<WaveData>> all_wav;
+  LoadWavData(all_wav, filepath_);
+
+  if (all_wav.empty()) {
+    std::cerr << "[FileAudioChunks] Error: Failed to load WAV file, " << filepath_ << std::endl;
+    return false;
+  }
+
+  wav_data_ = all_wav[0];  // Use the first WAV file
+
+  CalculateChunkSizeBytes(GetSampleRateHz());
+
+  return true;
+}
+
+void
+FileAudioChunks::ProcessAudioData()
+{
+  SplitIntoChunks();
+}
+
+// FileAudioChunks getter implementations
+std::string
+FileAudioChunks::GetFilepath() const
+{
+  return filepath_;
+}
+
+int
+FileAudioChunks::GetSampleRateHz() const
+{
+  return wav_data_->sample_rate;
+}
+
+int
+FileAudioChunks::GetNumChannels() const
+{
+  return wav_data_->channels;
+}
+
+int
+FileAudioChunks::GetBitDepth() const
+{
+  // Calculate bit depth from data size and sample rate
+  if (wav_data_->channels > 0 && wav_data_->sample_rate > 0) {
+    return (wav_data_->data.size() * 8) / (wav_data_->channels * wav_data_->sample_rate);
+  }
+  return 16;  // Default to 16-bit
+}
+
+double
+FileAudioChunks::GetDurationSeconds() const
+{
+  if (wav_data_->sample_rate > 0 && wav_data_->channels > 0) {
+    return static_cast<double>(wav_data_->data.size()) /
+           (wav_data_->sample_rate * wav_data_->channels * 2);  // Assuming 16-bit
+  }
+  return 0.0;
+}
+
+int
+FileAudioChunks::GetNumSamples() const
+{
+  if (wav_data_->channels > 0) {
+    return wav_data_->data.size() / (wav_data_->channels * 2);  // Assuming 16-bit
+  }
+  return 0;
+}
+
+// ============================================================================
+// MicrophoneChunks derived class implementation
+// ============================================================================
+
+MicrophoneChunks::MicrophoneChunks(
+    const std::string& device_name, const int& chunk_size_ms, int sample_rate, int num_channels,
+    int bit_depth)
+    : AudioChunks(chunk_size_ms), device_name_(device_name), alsa_handle_(nullptr),
+      sample_rate_(sample_rate), num_channels_(num_channels), bit_depth_(bit_depth),
+      is_capturing_(false), request_exit_(false)
+{
+}
+
+MicrophoneChunks::~MicrophoneChunks()
+{
+  StopCapture();
+  CloseAudioDevice();
+}
+
+bool
+MicrophoneChunks::OpenAudioDevice()
+{
+  int rc;
+  static snd_output_t* log;
+
+  std::cout << "[MicrophoneChunks] Opening ALSA device: " << device_name_ << std::endl;
+  std::cout << "[MicrophoneChunks] Sample rate: " << sample_rate_
+            << " Hz, Channels: " << num_channels_ << std::endl;
+
+  if ((rc = snd_pcm_open(&alsa_handle_, device_name_.c_str(), SND_PCM_STREAM_CAPTURE, 0)) < 0) {
+    std::cerr << "[MicrophoneChunks] Unable to open PCM device for recording: " << snd_strerror(rc)
+              << std::endl;
+    return false;
+  }
+
+  if ((rc = snd_output_stdio_attach(&log, stderr, 0)) < 0) {
+    std::cerr << "[MicrophoneChunks] Unable to attach log output: " << snd_strerror(rc)
+              << std::endl;
+    return false;
+  }
+
+  // Set audio parameters
+  snd_pcm_format_t format = (bit_depth_ == 16) ? SND_PCM_FORMAT_S16_LE : SND_PCM_FORMAT_S32_LE;
+  unsigned int latency = 100000;  // 100ms latency
+
+  if ((rc = snd_pcm_set_params(
+           alsa_handle_, format, SND_PCM_ACCESS_RW_INTERLEAVED, num_channels_, sample_rate_, 1,
+           latency)) < 0) {
+    std::cerr << "[MicrophoneChunks] snd_pcm_set_params error: " << snd_strerror(rc) << std::endl;
+    return false;
+  }
+
+  // Set software parameters for capture
+  snd_pcm_sw_params_t* sw_params = nullptr;
+  if ((rc = snd_pcm_sw_params_malloc(&sw_params)) < 0) {
+    std::cerr << "[MicrophoneChunks] snd_pcm_sw_params_malloc error: " << snd_strerror(rc)
+              << std::endl;
+    return false;
+  }
+
+  if ((rc = snd_pcm_sw_params_current(alsa_handle_, sw_params)) < 0) {
+    std::cerr << "[MicrophoneChunks] snd_pcm_sw_params_current error: " << snd_strerror(rc)
+              << std::endl;
+    snd_pcm_sw_params_free(sw_params);
+    return false;
+  }
+
+  if ((rc = snd_pcm_sw_params_set_start_threshold(alsa_handle_, sw_params, 1)) < 0) {
+    std::cerr << "[MicrophoneChunks] snd_pcm_sw_params_set_start_threshold failed: "
+              << snd_strerror(rc) << std::endl;
+    snd_pcm_sw_params_free(sw_params);
+    return false;
+  }
+
+  if ((rc = snd_pcm_sw_params(alsa_handle_, sw_params)) < 0) {
+    std::cerr << "[MicrophoneChunks] snd_pcm_sw_params failed: " << snd_strerror(rc) << std::endl;
+    snd_pcm_sw_params_free(sw_params);
+    return false;
+  }
+
+  snd_pcm_sw_params_free(sw_params);
+
+  std::cout << "[MicrophoneChunks] Successfully opened ALSA device" << std::endl;
+  return true;
+}
+
+void
+MicrophoneChunks::CloseAudioDevice()
+{
+  if (alsa_handle_) {
+    snd_pcm_close(alsa_handle_);
+    alsa_handle_ = nullptr;
+    std::cout << "[MicrophoneChunks] Closed ALSA device" << std::endl;
+  }
+}
+
+bool
+MicrophoneChunks::InitializeAudio()
+{
+  std::cout << "[MicrophoneChunks] Initializing microphone audio for device: " << device_name_
+            << std::endl;
+
+  if (!OpenAudioDevice()) {
+    std::cerr << "[MicrophoneChunks] Error: Failed to open audio device" << std::endl;
+    return false;
+  }
+
+  CalculateChunkSizeBytes(sample_rate_);
+
+  return true;
+}
+
+void
+MicrophoneChunks::ProcessAudioData()
+{
+  // For microphone, we don't pre-process data - it comes in real-time
+  // This method is called during Init() but doesn't populate chunks initially
+  std::cout << "[MicrophoneChunks] Microphone initialized, ready for capture" << std::endl;
+}
+
+bool
+MicrophoneChunks::StartCapture()
+{
+  if (is_capturing_) {
+    std::cout << "[MicrophoneChunks] Already capturing audio" << std::endl;
+    return true;
+  }
+
+  if (!initialized_) {
+    std::cerr << "[MicrophoneChunks] Error: Microphone not initialized" << std::endl;
+    return false;
+  }
+
+  request_exit_ = false;
+  is_capturing_ = true;
+
+  // Start capture thread
+  capture_thread_ = std::thread(&MicrophoneChunks::CaptureThreadMain, this);
+
+  std::cout << "[MicrophoneChunks] Started audio capture" << std::endl;
+  return true;
+}
+
+void
+MicrophoneChunks::StopCapture()
+{
+  if (!is_capturing_) {
+    return;
+  }
+
+  request_exit_ = true;
+  is_capturing_ = false;
+
+  if (capture_thread_.joinable()) {
+    capture_thread_.join();
+  }
+
+  std::cout << "[MicrophoneChunks] Stopped audio capture" << std::endl;
+}
+
+void
+MicrophoneChunks::CaptureThreadMain()
+{
+  std::cout << "[MicrophoneChunks] Capture thread started" << std::endl;
+
+  const size_t chunk_size = chunk_size_bytes_;
+  std::vector<char> chunk(chunk_size);
+
+  while (is_capturing_ && !request_exit_) {
+    // Read audio chunk from microphone
+    snd_pcm_sframes_t frames_read =
+        snd_pcm_readi(alsa_handle_, &chunk[0], chunk_size / sizeof(int16_t));
+
+    if (frames_read < 0) {
+      std::cerr << "[MicrophoneChunks] Read failed: " << snd_strerror(frames_read) << std::endl;
+      // Try to recover from error
+      if (snd_pcm_recover(alsa_handle_, frames_read, 0) < 0) {
+        std::cerr << "[MicrophoneChunks] Failed to recover from error" << std::endl;
+        break;
+      }
+      continue;
+    }
+
+    if (frames_read > 0) {
+      // Convert frames to bytes
+      size_t bytes_read = frames_read * sizeof(int16_t);
+
+      // Create chunk with actual data read
+      std::vector<char> actual_chunk(chunk.begin(), chunk.begin() + bytes_read);
+      std::string chunk_base64 = EncodeBase64(actual_chunk);
+
+      // Add to chunks with thread safety
+      {
+        std::lock_guard<std::mutex> lock(chunks_mutex_);
+        chunk_base64s_.push_back(chunk_base64);
+
+        // Keep only last 100 chunks to prevent memory issues
+        if (chunk_base64s_.size() > 100) {
+          chunk_base64s_.erase(chunk_base64s_.begin());
+        }
+      }
+
+      // Notify waiting threads
+      chunks_cv_.notify_all();
+
+      std::cout << "[MicrophoneChunks] Captured chunk " << chunk_base64s_.size() << " ("
+                << bytes_read << " bytes)" << std::endl;
+    }
+  }
+
+  std::cout << "[MicrophoneChunks] Capture thread ended" << std::endl;
+}
+
+// MicrophoneChunks getter implementations
+std::string
+MicrophoneChunks::GetDeviceName() const
+{
+  return device_name_;
+}
+
+bool
+MicrophoneChunks::IsCapturing() const
+{
+  return is_capturing_;
+}
+
+int
+MicrophoneChunks::GetSampleRateHz() const
+{
+  return sample_rate_;
+}
+
+int
+MicrophoneChunks::GetNumChannels() const
+{
+  return num_channels_;
+}
+
+int
+MicrophoneChunks::GetBitDepth() const
+{
+  return bit_depth_;
+}
+
+double
+MicrophoneChunks::GetDurationSeconds() const
+{
+  // For microphone, duration is ongoing - return 0
+  return 0.0;
+}
+
+int
+MicrophoneChunks::GetNumSamples() const
+{
+  // For microphone, samples are ongoing - return 0
+  return 0;
+}
+
+std::string
+MicrophoneChunks::GetLatestChunk() const
+{
+  std::lock_guard<std::mutex> lock(chunks_mutex_);
+  if (chunk_base64s_.empty()) {
+    return "";
+  }
+  return chunk_base64s_.back();
+}
+
+void
+MicrophoneChunks::WaitForNewChunk()
+{
+  std::unique_lock<std::mutex> lock(chunks_mutex_);
+  chunks_cv_.wait(lock, [this] { return !chunk_base64s_.empty(); });
+}
+
+}  // namespace nvidia::riva::realtime
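Taken together, `FileAudioChunks` gives a self-contained path from a WAV file on disk to base64-encoded `pcm16` chunks. A minimal standalone sketch of that flow (illustrative only, not part of this PR; the file name is hypothetical):
```cpp
// Sketch: load a WAV file and walk its base64-encoded chunks,
// using the FileAudioChunks API declared in audio_chunks.h below.
#include <iostream>

#include "riva/clients/realtime/audio_chunks.h"

int main() {
  using nvidia::riva::realtime::FileAudioChunks;

  // 100 ms chunks, matching the client's default --chunk_duration_ms.
  FileAudioChunks chunks("individual_audio_file.wav", 100);
  if (!chunks.Init()) {
    return 1;  // file missing, not a .wav, or failed to parse
  }

  std::cout << chunks.size() << " chunks of " << chunks.GetChunkSizeBytes()
            << " bytes (" << chunks.GetDurationSeconds() << " s total)" << std::endl;

  size_t total_b64 = 0;
  for (const std::string& chunk_b64 : chunks) {
    total_b64 += chunk_b64.size();  // each entry feeds one input_audio_buffer.append
  }
  std::cout << "base64 payload: " << total_b64 << " bytes" << std::endl;
  return 0;
}
```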
diff --git a/riva/clients/realtime/audio_chunks.h b/riva/clients/realtime/audio_chunks.h
new file mode 100644
index 0000000..fa15560
--- /dev/null
+++ b/riva/clients/realtime/audio_chunks.h
@@ -0,0 +1,166 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ */
+
+#ifndef AUDIO_CHUNKS_H
+#define AUDIO_CHUNKS_H
+
+#include <alsa/asoundlib.h>
+
+#include <atomic>
+#include <condition_variable>
+#include <filesystem>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "riva/utils/wav/wav_data.h"
+#include "riva/utils/wav/wav_reader.h"
+
+namespace fs = std::filesystem;
+
+namespace nvidia::riva::realtime {
+
+// Forward declarations - we'll include the actual headers in the .cpp file
+void LoadWavData(std::vector<std::shared_ptr<WaveData>>& all_wav, const std::string& filepath);
+
+// Base class for audio input
+class AudioChunks {
+ protected:
+  bool initialized_ = false;
+  size_t chunk_size_ms_;
+  size_t chunk_size_bytes_;
+  std::vector<std::string> chunk_base64s_;
+
+  // Common methods for derived classes
+  void CalculateChunkSizeBytes(int sample_rate);
+  std::string EncodeBase64(const std::vector<char>& data);
+
+  // Virtual methods for derived classes to implement
+  virtual bool InitializeAudio() = 0;
+  virtual void ProcessAudioData() = 0;
+
+ public:
+  AudioChunks(const int& chunk_size_ms);
+  virtual ~AudioChunks() = default;
+
+  bool Init();
+
+  // Getters
+  size_t GetChunkSizeMs() const;
+  size_t GetChunkSizeBytes() const;
+  bool IsInitialized() const;
+
+  // Audio properties (to be implemented by derived classes)
+  virtual int GetSampleRateHz() const = 0;
+  virtual int GetNumChannels() const = 0;
+  virtual int GetBitDepth() const = 0;
+  virtual double GetDurationSeconds() const = 0;
+  virtual int GetNumSamples() const = 0;
+  const std::vector<std::string>& GetChunkBase64s() const;
+
+  // Iterator support
+  using iterator = std::vector<std::string>::iterator;
+  using const_iterator = std::vector<std::string>::const_iterator;
+  using reverse_iterator = std::vector<std::string>::reverse_iterator;
+  using const_reverse_iterator = std::vector<std::string>::const_reverse_iterator;
+
+  // Iterator methods
+  iterator begin() { return chunk_base64s_.begin(); }
+  const_iterator begin() const { return chunk_base64s_.begin(); }
+  iterator end() { return chunk_base64s_.end(); }
+  const_iterator end() const { return chunk_base64s_.end(); }
+
+  // Reverse iterator methods
+  reverse_iterator rbegin() { return chunk_base64s_.rbegin(); }
+  const_reverse_iterator rbegin() const { return chunk_base64s_.rbegin(); }
+  reverse_iterator rend() { return chunk_base64s_.rend(); }
+  const_reverse_iterator rend() const { return chunk_base64s_.rend(); }
+
+  // Const iterator methods
+  const_iterator cbegin() const { return chunk_base64s_.cbegin(); }
+  const_iterator cend() const { return chunk_base64s_.cend(); }
+  const_reverse_iterator crbegin() const { return chunk_base64s_.crbegin(); }
+  const_reverse_iterator crend() const { return chunk_base64s_.crend(); }
+
+  // Size methods
+  size_t size() const { return chunk_base64s_.size(); }
+  bool empty() const { return chunk_base64s_.empty(); }
+};
+
+// Derived class for file-based audio input
+class FileAudioChunks : public AudioChunks {
+ private:
+  std::string filepath_;
+  std::shared_ptr<WaveData> wav_data_;
+
+  void SplitIntoChunks();
+  bool InitializeAudio() override;
+  void ProcessAudioData() override;
+
+ public:
+  FileAudioChunks(const std::string& filepath, const int& chunk_size_ms);
+  ~FileAudioChunks() = default;
+
+  std::string GetFilepath() const;
+  int GetSampleRateHz() const override;
+  int GetNumChannels() const override;
+  int GetBitDepth() const override;
+  double GetDurationSeconds() const override;
+  int GetNumSamples() const override;
+};
+
+// Derived class for microphone input
+class MicrophoneChunks : public AudioChunks {
+ private:
+  std::string device_name_;
+  snd_pcm_t* alsa_handle_;
+  std::thread capture_thread_;
+  std::atomic<bool> is_capturing_;
+  std::atomic<bool> request_exit_;
+  mutable std::mutex chunks_mutex_;  // Make mutable for const member functions
+  std::condition_variable chunks_cv_;
+
+  // Audio capture parameters
+  int sample_rate_;
+  int num_channels_;
+  int bit_depth_;
+
+  // Capture thread function
+  void CaptureThreadMain();
+  bool OpenAudioDevice();
+  void CloseAudioDevice();
+  bool InitializeAudio() override;
+  void ProcessAudioData() override;
+
+ public:
+  MicrophoneChunks(
+      const std::string& device_name, const int& chunk_size_ms, int sample_rate = 16000,
+      int num_channels = 1, int bit_depth = 16);
+  ~MicrophoneChunks();
+
+  // Microphone-specific methods
+  bool StartCapture();
+  void StopCapture();
+  bool IsCapturing() const;
+  std::string GetDeviceName() const;
+
+  // Audio properties
+  int GetSampleRateHz() const override;
+  int GetNumChannels() const override;
+  int GetBitDepth() const override;
+  double GetDurationSeconds() const override;
+  int GetNumSamples() const override;
+
+  // Real-time chunk access
+  std::string GetLatestChunk() const;
+  void WaitForNewChunk();
+};
+
+}  // namespace nvidia::riva::realtime
+
+#endif  // AUDIO_CHUNKS_H
\ No newline at end of file
diff --git a/riva/clients/realtime/base_client.cpp b/riva/clients/realtime/base_client.cpp
new file mode 100644
index 0000000..bdeab1b
--- /dev/null
+++ b/riva/clients/realtime/base_client.cpp
@@ -0,0 +1,226 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "base_client.h"
+
+#include <functional>
+#include <iostream>
+
+
+nvidia::riva::realtime::WebSocketClientBase::WebSocketClientBase(const std::string& uri)
+    : connected_(false), connectionClosedByServer_(false), connectionTimeoutMs_(std::size_t(5000)),
+      uri_(uri)
+{
+  // Set up logging - suppress verbose internal messages
+  wsClient_.set_access_channels(websocketpp::log::alevel::connect);
+  wsClient_.set_access_channels(websocketpp::log::alevel::disconnect);
+  wsClient_.set_access_channels(websocketpp::log::alevel::fail);
+  wsClient_.set_access_channels(websocketpp::log::alevel::app);
+
+  // Initialize ASIO
+  wsClient_.init_asio();
+
+  // Set up handlers
+  wsClient_.set_open_handler(
+      std::bind(&nvidia::riva::realtime::WebSocketClientBase::OnOpen, this, std::placeholders::_1));
+  wsClient_.set_close_handler(std::bind(
+      &nvidia::riva::realtime::WebSocketClientBase::OnClose, this, std::placeholders::_1));
+  wsClient_.set_fail_handler(
+      std::bind(&nvidia::riva::realtime::WebSocketClientBase::OnFail, this, std::placeholders::_1));
+  wsClient_.set_message_handler(std::bind(
+      &nvidia::riva::realtime::WebSocketClientBase::OnMessage, this, std::placeholders::_1,
+      std::placeholders::_2));
+}
+
+void
+nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout(
+    const std::size_t connectionTimeoutMs)
+{
+  connectionTimeoutMs_ = connectionTimeoutMs;
+}
+
+std::size_t
+nvidia::riva::realtime::WebSocketClientBase::GetConnectionTimeout()
+{
+  return connectionTimeoutMs_;
+}
+
+void
+nvidia::riva::realtime::WebSocketClientBase::SetVerboseLogging(bool verbose)
+{
+  if (verbose) {
+    // Enable all logging channels
+    wsClient_.set_access_channels(websocketpp::log::alevel::all);
+    wsClient_.clear_access_channels(websocketpp::log::alevel::frame_payload);
+  } else {
+    // Minimal logging - only important events
+    wsClient_.clear_access_channels(websocketpp::log::alevel::all);
+    wsClient_.set_access_channels(websocketpp::log::alevel::connect);
+    wsClient_.set_access_channels(websocketpp::log::alevel::disconnect);
+    wsClient_.set_access_channels(websocketpp::log::alevel::fail);
+    wsClient_.set_access_channels(websocketpp::log::alevel::app);
+  }
+}
+
+void
+nvidia::riva::realtime::WebSocketClientBase::Connect(const std::string& uri)
+{
+  uri_ = uri;
+  websocketpp::lib::error_code ec;
+
+  websocketpp_client::connection_ptr con = wsClient_.get_connection(uri, ec);
+  if (ec) {
+    std::cerr << "Could not create connection: " << ec.message() << std::endl;
+    return;
+  }
+
+  wsClient_.connect(con);
+}
+
+void
+nvidia::riva::realtime::WebSocketClientBase::Run()
+{
+  wsClient_.run();
+}
+
+void
+nvidia::riva::realtime::WebSocketClientBase::Send(const std::string& message)
+{
+  std::lock_guard<std::mutex> lock(mutex_);
+  if (connected_) {
+    websocketpp::lib::error_code ec;
+    wsClient_.send(connectionHdl_, message, websocketpp::frame::opcode::text, ec);
+    if (ec) {
+      std::cerr << "Send failed: " << ec.message() << std::endl;
+    }
+  }
+}
+
+void
+nvidia::riva::realtime::WebSocketClientBase::Close()
+{
+  std::lock_guard<std::mutex> lock(mutex_);
+  if (connected_) {
+    websocketpp::lib::error_code ec;
+    wsClient_.close(connectionHdl_, websocketpp::close::status::normal, "Client closing", ec);
+  }
+}
+
+void
+nvidia::riva::realtime::WebSocketClientBase::SendJsonMessage(
+    const std::string& type, const std::string& data)
+{
+  std::lock_guard<std::mutex> lock(mutex_);
+  if (connected_) {
+    rapidjson::Document doc;
+    doc.SetObject();
+    rapidjson::Document::AllocatorType& allocator = doc.GetAllocator();
+
+    doc.AddMember("type", rapidjson::Value(type.c_str(), allocator), allocator);
+    if (!data.empty()) {
+      doc.AddMember("data", rapidjson::Value(data.c_str(), allocator), allocator);
+    }
+
+    rapidjson::StringBuffer buffer;
+    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+    doc.Accept(writer);
+
+    websocketpp::lib::error_code ec;
+    wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec);
+    if (ec) {
+      std::cerr << "Send failed: " << ec.message() << std::endl;
+    } else {
+      std::cout << "Sent: " << buffer.GetString() << std::endl;
+    }
+  }
+}
+
+void
+nvidia::riva::realtime::WebSocketClientBase::OnOpen(websocketpp::connection_hdl hdl)
+{
+  std::lock_guard<std::mutex> lock(mutex_);
+  connectionHdl_ = hdl;
+  connected_ = true;
+
+  // Notify waiting threads that connection is established
+  {
+    std::lock_guard<std::mutex> conn_lock(connectionMutex_);
+    connectionCv_.notify_one();
+  }
+
+  std::cout << "Connected to " << uri_ << std::endl;
+}
+
+void
+nvidia::riva::realtime::WebSocketClientBase::OnClose(websocketpp::connection_hdl hdl)
+{
+  (void)hdl;  // Suppress unused parameter warning
+  std::lock_guard<std::mutex> lock(mutex_);
+  connected_ = false;
+
+  // Check if this was a server-initiated close
+  {
+    std::lock_guard<std::mutex> conn_lock(connectionMutex_);
+    connectionClosedByServer_ = true;
+  }
+  connectionCv_.notify_one();
+
+  std::cout << "Connection closed" << std::endl;
+}
+
+void
+nvidia::riva::realtime::WebSocketClientBase::OnFail(websocketpp::connection_hdl hdl)
+{
+  (void)hdl;  // Suppress unused parameter warning
+  std::lock_guard<std::mutex> lock(mutex_);
+  connected_ = false;
+
+  // Mark as server-initiated failure
+  {
+    std::lock_guard<std::mutex> conn_lock(connectionMutex_);
+    connectionClosedByServer_ = true;
+  }
+  connectionCv_.notify_one();
+
+  std::cout << "Connection failed" << std::endl;
+}
+
+void
+nvidia::riva::realtime::WebSocketClientBase::OnMessage(
+    websocketpp::connection_hdl hdl, message_ptr msg)
+{
+  (void)hdl;  // Suppress unused parameter warning
+  HandleMessage(msg->get_payload());
+}
+
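+// The WaitFor* helpers below block the caller on connectionCv_ until the
+// corresponding event is observed or connectionTimeoutMs_ elapses; they
+// return false on timeout so callers can tell a timeout from a real event.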
+bool
+nvidia::riva::realtime::WebSocketClientBase::WaitForConnection()
+{
+  std::unique_lock<std::mutex> lock(connectionMutex_);
+  return connectionCv_.wait_for(
+      lock,
+      std::chrono::milliseconds(connectionTimeoutMs_),  // Use the provided timeout
+      [this] { return connected_; });
+}
+
+bool
+nvidia::riva::realtime::WebSocketClientBase::WaitForDisconnection()
+{
+  std::unique_lock<std::mutex> lock(connectionMutex_);
+  return connectionCv_.wait_for(
+      lock,
+      std::chrono::milliseconds(connectionTimeoutMs_),  // Use the provided timeout
+      [this] { return !connected_; });
+}
+
+bool
+nvidia::riva::realtime::WebSocketClientBase::WaitForServerClose()
+{
+  std::unique_lock<std::mutex> lock(connectionMutex_);
+  return connectionCv_.wait_for(
+      lock,
+      std::chrono::milliseconds(connectionTimeoutMs_),  // Use the provided timeout
+      [this] { return connectionClosedByServer_; });
+}
diff --git a/riva/clients/realtime/base_client.h b/riva/clients/realtime/base_client.h
new file mode 100644
index 0000000..ec0ec8c
--- /dev/null
+++ b/riva/clients/realtime/base_client.h
@@ -0,0 +1,87 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ */
+
+#ifndef BASE_REALTIME_CLIENT_H
+#define BASE_REALTIME_CLIENT_H
+
+#include <rapidjson/document.h>
+#include <rapidjson/stringbuffer.h>
+#include <rapidjson/writer.h>
+
+#include <websocketpp/client.hpp>
+#include <websocketpp/config/asio_no_tls_client.hpp>
+
+#include <chrono>
+#include <condition_variable>
+#include <functional>
+#include <iostream>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <thread>
+
+#include "audio_chunks.h"
+
+namespace nvidia::riva::realtime {
+class WebSocketClientBase {
+ protected:
+  typedef websocketpp::client<websocketpp::config::asio_client> websocketpp_client;
+  typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
+
+  websocketpp_client wsClient_;
+  websocketpp::connection_hdl connectionHdl_;
+
+  std::string uri_;
+  bool connected_;
+  std::mutex mutex_;
+
+  // Connection state
+  bool connectionClosedByServer_;
+  std::condition_variable connectionCv_;
+  std::mutex connectionMutex_;
+  std::size_t connectionTimeoutMs_;
+
+  // Protected access to websocket client for derived classes
+  websocketpp_client& GetWsClient() { return wsClient_; }
+  websocketpp::connection_hdl& GetConnection() { return connectionHdl_; }
+  std::mutex& GetConnectionMutex() { return connectionMutex_; }
+
+ public:
+  WebSocketClientBase(const std::string& uri);
+  virtual ~WebSocketClientBase() = default;
+
+  // Connection timeout
+  void SetConnectionTimeout(const std::size_t connectionTimeoutMs);
+  std::size_t GetConnectionTimeout();
+
+  // Connection status
+  bool IsConnected() const { return connected_; }
+  bool IsConnectionClosedByServer() const { return connectionClosedByServer_; }
+  bool IsConnectionOpen() const { return connected_ && !connectionClosedByServer_; }
+  bool IsConnectionClosed() const { return !connected_ || connectionClosedByServer_; }
+
+  // Control logging verbosity
+  void SetVerboseLogging(bool verbose);
+
+  // Connection management
+  void Connect(const std::string& uri);
+  void Run();
+  void Send(const std::string& message);
+  void Close();
+  void SendJsonMessage(const std::string& type, const std::string& data = "");
+
+  // Connection waiting methods
+  bool WaitForConnection();
+  bool WaitForDisconnection();
+  bool WaitForServerClose();
+
+  // Event handlers
+  void OnOpen(websocketpp::connection_hdl hdl);
+  void OnClose(websocketpp::connection_hdl hdl);
+  void OnFail(websocketpp::connection_hdl hdl);
+  void OnMessage(websocketpp::connection_hdl hdl, message_ptr msg);
+  virtual void HandleMessage(const std::string& message) = 0;
+};
+}  // namespace nvidia::riva::realtime
+#endif  // BASE_REALTIME_CLIENT_H
\ No newline at end of file
diff --git a/riva/clients/realtime/realtime_client.cpp b/riva/clients/realtime/realtime_client.cpp
new file mode 100644
index 0000000..4c583a0
--- /dev/null
+++ b/riva/clients/realtime/realtime_client.cpp
@@ -0,0 +1,720 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "realtime_client.h"
+
+#include <arpa/inet.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include <chrono>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <thread>
+
+#include "base_client.h"
+
+// Helper method for HTTP requests using raw sockets
+std::string
+nvidia::riva::realtime::RealtimeClient::MakeHttpRequest(
+    const std::string& host, int port, const std::string& path, const std::string& method,
+    const std::string& body)
+{
+  int sock = socket(AF_INET, SOCK_STREAM, 0);
+  if (sock < 0) {
+    std::cerr << "Failed to create socket" << std::endl;
+    return "";
+  }
+
+  struct hostent* server = gethostbyname(host.c_str());
+  if (server == nullptr) {
+    std::cerr << "Failed to resolve host: " << host << std::endl;
+    close(sock);
+    return "";
+  }
+
+  struct sockaddr_in serv_addr;
+  memset(&serv_addr, 0, sizeof(serv_addr));
+  serv_addr.sin_family = AF_INET;
+  memcpy(&serv_addr.sin_addr.s_addr, server->h_addr, server->h_length);
+  serv_addr.sin_port = htons(port);
+
+  if (connect(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) < 0) {
+    std::cerr << "Failed to connect to " << host << ":" << port << std::endl;
+    close(sock);
+    return "";
+  }
+
+  // Build HTTP request
+  std::ostringstream request;
+  request << method << " " << path << " HTTP/1.1\r\n";
+  request << "Host: " << host << ":" << port << "\r\n";
+  request << "Content-Type: application/json\r\n";
+  request << "Content-Length: " << body.length() << "\r\n";
+  request << "Connection: close\r\n";
+  request << "\r\n";
+  request << body;
+
+  std::string request_str = request.str();
+
+  // Send request
+  if (send(sock, request_str.c_str(), request_str.length(), 0) < 0) {
+    std::cerr << "Failed to send HTTP request" << std::endl;
+    close(sock);
+    return "";
+  }
+
+  // Receive response
+  std::string response;
+  char buffer[4096];
+  int bytes_received;
+
+  while ((bytes_received = recv(sock, buffer, sizeof(buffer) - 1, 0)) > 0) {
+    buffer[bytes_received] = '\0';
+    response += buffer;
+  }
+
+  close(sock);
+
+  // Extract JSON body from HTTP response
+  size_t body_start = response.find("\r\n\r\n");
+  if (body_start != std::string::npos) {
+    return response.substr(body_start + 4);
+  }
+
+  return response;
+}
+
+bool
+nvidia::riva::realtime::RealtimeClient::InitializeHttpSession()
+{
+  if (server_url_.empty()) {
+    std::cerr << "Server URL not set" << std::endl;
+    return false;
+  }
+
+  // Parse server URL to extract host and port
+  std::string host = server_url_;
+  int port = 80;  // Default HTTP port
+
+  // Check if port is specified
+  size_t colon_pos = host.find(':');
+  if (colon_pos != std::string::npos) {
+    port = std::stoi(host.substr(colon_pos + 1));
+    host = host.substr(0, colon_pos);
+  }
+
+  std::string path = "/v1/realtime/transcription_sessions";
+  std::string response_body = MakeHttpRequest(host, port, path, "POST", "{}");
+
+  if (response_body.empty()) {
+    std::cerr << "HTTP request failed" << std::endl;
+    return false;
+  }
+
+  try {
+    // Parse JSON response using rapidjson
+    rapidjson::Document session_data;
+    if (session_data.Parse(response_body.c_str()).HasParseError()) {
+      std::cerr << "Failed to parse JSON response" << std::endl;
+      return false;
+    }
+
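+    // The response is expected to be a JSON transcription-session object of the
+    // form { "id": ..., "input_audio_transcription": {...}, "recognition_config": {...},
+    //   "speaker_diarization": {...}, "endpointing_config": {...} };
+    // the field-by-field reads below mirror that shape.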
+    // Extract session ID
+    if (session_data.HasMember("id")) {
+      session_id_ = session_data["id"].GetString();
+    } else {
+      std::cerr << "No session ID found in response" << std::endl;
+      return false;
+    }
+
+    // Store server defaults but don't overwrite user-provided values
+    SessionConfig serverDefaults;
+
+    if (session_data.HasMember("input_audio_transcription")) {
+      const auto& transcription = session_data["input_audio_transcription"];
+      if (transcription.HasMember("language")) {
+        serverDefaults.language_code_ = transcription["language"].GetString();
+      }
+      if (transcription.HasMember("model")) {
+        serverDefaults.model_name_ = transcription["model"].GetString();
+      }
+    }
+
+    if (session_data.HasMember("recognition_config")) {
+      const auto& recognition = session_data["recognition_config"];
+      if (recognition.HasMember("max_alternatives")) {
+        serverDefaults.max_alternatives_ = recognition["max_alternatives"].GetInt();
+      }
+      if (recognition.HasMember("enable_automatic_punctuation")) {
+        serverDefaults.automatic_punctuation_ =
+            recognition["enable_automatic_punctuation"].GetBool();
+      }
+      if (recognition.HasMember("enable_word_time_offsets")) {
+        serverDefaults.word_time_offsets_ = recognition["enable_word_time_offsets"].GetBool();
+      }
+      if (recognition.HasMember("enable_profanity_filter")) {
+        serverDefaults.profanity_filter_ = recognition["enable_profanity_filter"].GetBool();
+      }
+      if (recognition.HasMember("enable_verbatim_transcripts")) {
+        serverDefaults.verbatim_transcripts_ = recognition["enable_verbatim_transcripts"].GetBool();
+      }
+    }
+
+    if (session_data.HasMember("speaker_diarization")) {
+      const auto& diarization = session_data["speaker_diarization"];
+      if (diarization.HasMember("enable_speaker_diarization")) {
+        serverDefaults.speaker_diarization_ = diarization["enable_speaker_diarization"].GetBool();
+      }
+      if (diarization.HasMember("max_speaker_count")) {
+        serverDefaults.diarization_max_speakers_ = diarization["max_speaker_count"].GetInt();
+      }
+    }
+
+    if (session_data.HasMember("endpointing_config")) {
+      const auto& endpointing = session_data["endpointing_config"];
+      if (endpointing.HasMember("start_history")) {
+        serverDefaults.start_history_ = endpointing["start_history"].GetInt();
+      }
+      if (endpointing.HasMember("start_threshold")) {
+        serverDefaults.start_threshold_ = endpointing["start_threshold"].GetDouble();
+      }
+      if (endpointing.HasMember("stop_history")) {
+        serverDefaults.stop_history_ = endpointing["stop_history"].GetInt();
+      }
+      if (endpointing.HasMember("stop_threshold")) {
+        serverDefaults.stop_threshold_ = endpointing["stop_threshold"].GetDouble();
+      }
+      if (endpointing.HasMember("stop_history_eou")) {
+        serverDefaults.stop_history_eou_ = endpointing["stop_history_eou"].GetInt();
+      }
+      if (endpointing.HasMember("stop_threshold_eou")) {
+        serverDefaults.stop_threshold_eou_ = endpointing["stop_threshold_eou"].GetDouble();
+      }
+    }
+
+    // Only use server defaults for values that haven't been set by user
+    if (sessionConfig_.language_code_.empty()) {
+      sessionConfig_.language_code_ = serverDefaults.language_code_;
+    }
+    if (sessionConfig_.model_name_.empty()) {
+      sessionConfig_.model_name_ = serverDefaults.model_name_;
+    }
+    if (sessionConfig_.max_alternatives_ == 0) {
+      sessionConfig_.max_alternatives_ = serverDefaults.max_alternatives_;
+    }
+    if (sessionConfig_.start_history_ == -1) {
+      sessionConfig_.start_history_ = serverDefaults.start_history_;
+    }
+    if (sessionConfig_.start_threshold_ == -1.0) {
+      sessionConfig_.start_threshold_ = serverDefaults.start_threshold_;
+    }
+    if (sessionConfig_.stop_history_ == -1) {
+      sessionConfig_.stop_history_ = serverDefaults.stop_history_;
+    }
+    if (sessionConfig_.stop_threshold_ == -1.0) {
+      sessionConfig_.stop_threshold_ = serverDefaults.stop_threshold_;
+    }
+    if (sessionConfig_.stop_history_eou_ == -1) {
+      sessionConfig_.stop_history_eou_ = serverDefaults.stop_history_eou_;
+    }
+    if (sessionConfig_.stop_threshold_eou_ == -1.0) {
+      sessionConfig_.stop_threshold_eou_ = serverDefaults.stop_threshold_eou_;
+    }
+
+    // Convert rapidjson document to string for logging
+    rapidjson::StringBuffer buffer;
+    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+    session_data.Accept(writer);
+
+    std::cout << "[" << objectName_ << "] Session initialized with defaults: " << buffer.GetString()
+              << std::endl;
+    return true;
+  }
+  catch (const std::exception& e) {
+    std::cerr << "Failed to parse session response: " << e.what() << std::endl;
+    return false;
+  }
+}
+
+nvidia::riva::realtime::RealtimeClient::RealtimeClient(
+    const std::string& objectName, const std::shared_ptr<AudioChunks> audioChunksPtr,
+    nvidia::riva::utils::PerformanceStats& perfCounter)
+    : WebSocketClientBase("ws://127.0.0.1:9090/v1/realtime?intent=transcription"),
+      sessionInitialized_(false), sessionUpdated_(false), transcriptionCompleted_(false),
+      finalTranscriptionCount_(0), connectionTimeoutInMs_(std::size_t(10000)),
+      sessionInitTimeoutInMs_(std::size_t(10000)), sessionUpdateTimeoutInMs_(std::size_t(10000)),
+      transcriptionTimeoutInMs_(std::size_t(10000)), chunkDelayTimeInMs_(std::size_t(1000)),
+      objectName_(objectName), audioChunksPtr_(audioChunksPtr), perfCounter_(perfCounter)
+{
+  nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout(connectionTimeoutInMs_);
+}
+
+
+void
+nvidia::riva::realtime::RealtimeClient::SetTimingConfig(
+    const std::size_t connectionTimeoutInMs, const std::size_t sessionInitTimeoutInMs,
+    const std::size_t sessionUpdateTimeoutInMs, const std::size_t transcriptionTimeoutInMs,
+    const std::size_t chunkDelayTimeInMs)
+{
+  connectionTimeoutInMs_ = connectionTimeoutInMs;
+  sessionInitTimeoutInMs_ = sessionInitTimeoutInMs;
+  sessionUpdateTimeoutInMs_ = sessionUpdateTimeoutInMs;
+  transcriptionTimeoutInMs_ = transcriptionTimeoutInMs;
+  chunkDelayTimeInMs_ = chunkDelayTimeInMs;
+  nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout(connectionTimeoutInMs_);
+}
+
+void
+nvidia::riva::realtime::RealtimeClient::Log(const std::string& message)
+{
+  std::cout << "[" << objectName_ << "]" << message << std::endl;
+}
+
+bool
+nvidia::riva::realtime::RealtimeClient::WaitForTranscriptionCompletion()
+{
+  std::unique_lock<std::mutex> lock(transcriptionMutex_);
+
+  // Reset completion flag
+  transcriptionCompleted_ = false;
+
+  // Wait for completion event with timeout
+  bool completed = transcriptionCv_.wait_for(
+      lock, std::chrono::milliseconds(transcriptionTimeoutInMs_),
+      [this] { return transcriptionCompleted_; });
+
+  if (!completed) {
+    Log(" Timeout waiting for transcription completion after " +
+        std::to_string(transcriptionTimeoutInMs_) + " milliseconds");
+  } else if (transcriptionCompleted_) {
+    // Close the connection
+    Close();
+  }
+
+  return completed;
+}
+
+bool
+nvidia::riva::realtime::RealtimeClient::WaitForSessionUpdate()
+{
+  std::unique_lock<std::mutex> lock(sessionMutex_);
+
+  if (sessionUpdated_) {
+    return true;
+  }
+
+  // Wait for session update event with timeout
+  sessionUpdated_ = sessionCv_.wait_for(
+      lock, std::chrono::milliseconds(sessionUpdateTimeoutInMs_),
+      [this] { return sessionUpdated_; });
+
+  if (!sessionUpdated_) {
+    Log("Timeout waiting for session update after " + std::to_string(sessionUpdateTimeoutInMs_) +
+        " milliseconds");
+  }
+
+  return sessionUpdated_;
+}
+
+// Send audio buffer append message (inspired by Python realtime.py)
+void
+nvidia::riva::realtime::RealtimeClient::SendAudioAppend(const std::string& audioBase64)
+{
+  std::lock_guard<std::mutex> lock(connectionMutex_);
+  if (IsConnectionOpen()) {
+    rapidjson::Document doc;
+    doc.SetObject();
+    rapidjson::Document::AllocatorType& allocator = doc.GetAllocator();
+    doc.AddMember("type", rapidjson::Value("input_audio_buffer.append", allocator), allocator);
+    doc.AddMember("audio", rapidjson::Value(audioBase64.c_str(), allocator), allocator);
+
+    rapidjson::StringBuffer buffer;
+    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+    doc.Accept(writer);
+
+    websocketpp::lib::error_code ec;
+    wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec);
+    if (ec) {
+      Log("Audio append failed: " + ec.message());
+      // Mark connection as failed; connectionMutex_ is already held by `lock`,
+      // so re-locking it here would deadlock
+      connectionClosedByServer_ = true;
+    }
+  } else {
+    Log("Skipping audio append - connection closed");
+  }
+}
+
+// Send audio buffer commit message (inspired by Python realtime.py)
+void
+nvidia::riva::realtime::RealtimeClient::SendAudioCommit()
+{
+  std::lock_guard<std::mutex> lock(connectionMutex_);
+  if (IsConnectionOpen()) {
+    rapidjson::Document doc;
+    doc.SetObject();
+    rapidjson::Document::AllocatorType& allocator = doc.GetAllocator();
+
+    doc.AddMember("type", rapidjson::Value("input_audio_buffer.commit", allocator), allocator);
+
+    rapidjson::StringBuffer buffer;
+    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+    doc.Accept(writer);
+
+    websocketpp::lib::error_code ec;
+    wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec);
+    if (ec) {
+      Log("Audio commit failed: " + ec.message());
+      // Mark connection as failed; connectionMutex_ is already held by `lock`
+      connectionClosedByServer_ = true;
+    }
+  } else {
+    Log("Skipping audio commit - connection closed");
+  }
+}
+
+// Send audio buffer done message (inspired by Python realtime.py)
+void
+nvidia::riva::realtime::RealtimeClient::SendAudioDone()
+{
+  std::lock_guard<std::mutex> lock(connectionMutex_);
+  if (IsConnectionOpen()) {
+    rapidjson::Document doc;
+    doc.SetObject();
+    rapidjson::Document::AllocatorType& allocator = doc.GetAllocator();
+
+    doc.AddMember("type", rapidjson::Value("input_audio_buffer.done", allocator), allocator);
+
+    rapidjson::StringBuffer buffer;
+    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+    doc.Accept(writer);
+
+    websocketpp::lib::error_code ec;
+    wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec);
+    if (ec) {
+      Log("Audio done failed: " + ec.message());
+      // Mark connection as failed; connectionMutex_ is already held by `lock`
+      connectionClosedByServer_ = true;
+    } else {
+      Log("Audio streaming completed");
+    }
+  } else {
+    Log("Skipping audio done - connection closed");
+  }
+}
+
+// InitializeSession performs the HTTP bootstrap first, then pushes the
+// websocket session configuration
+bool
+nvidia::riva::realtime::RealtimeClient::InitializeSession()
+{
+  std::cout << "[" << objectName_ << "]"
+            << " Initializing session..." << std::endl;
+
+  // Step 1: Initialize HTTP session
+  if (!InitializeHttpSession()) {
+    std::cerr << "Failed to initialize HTTP session" << std::endl;
+    return false;
+  }
+
+  // Step 2: Wait for the initial connection and session creation
+  std::this_thread::sleep_for(std::chrono::milliseconds(3000));
+
+  // Step 3: Check if we're still connected
+  if (IsConnectionClosed()) {
+    std::cerr << "Connection lost during session initialization" << std::endl;
+    return false;
+  }
+
+  // Step 4: Update session configuration
+  return UpdateSessionConfig();
+}
+
+bool
+nvidia::riva::realtime::RealtimeClient::UpdateSessionConfig()
+{
+  int sampleRateHz = audioChunksPtr_->GetSampleRateHz();
+  int numChannels = audioChunksPtr_->GetNumChannels();
+
+  std::cout << "Updating session configuration..." << std::endl;
+  std::cout << "Using WAV file parameters - Sample rate: " << sampleRateHz
+            << " Hz, Channels: " << numChannels << std::endl;
+
+  // Create session configuration using sessionConfig_ (which now has defaults + user overrides)
+  rapidjson::Document doc;
+  doc.SetObject();
+  rapidjson::Document::AllocatorType& allocator = doc.GetAllocator();
+
+  // Create session config
+  rapidjson::Value session_config(rapidjson::kObjectType);
+
+  // Add modalities
+  rapidjson::Value modalities(rapidjson::kArrayType);
+  modalities.PushBack(rapidjson::Value("text", allocator), allocator);
+  session_config.AddMember("modalities", modalities, allocator);
+
+  // Add input audio format
+  session_config.AddMember("input_audio_format", rapidjson::Value("pcm16", allocator), allocator);
+
+  // Input audio transcription config
+  rapidjson::Value transcription_config(rapidjson::kObjectType);
+  transcription_config.AddMember(
+      "language", rapidjson::Value(sessionConfig_.language_code_.c_str(), allocator), allocator);
+  transcription_config.AddMember(
+      "model", rapidjson::Value(sessionConfig_.model_name_.c_str(), allocator), allocator);
+  transcription_config.AddMember("prompt", rapidjson::Value(rapidjson::kNullType), allocator);
+  session_config.AddMember("input_audio_transcription", transcription_config, allocator);
+
+  // Input audio params - use actual WAV file parameters
+  rapidjson::Value audio_params(rapidjson::kObjectType);
+  audio_params.AddMember("sample_rate_hz", sampleRateHz, allocator);
+  audio_params.AddMember("num_channels", numChannels, allocator);
+  session_config.AddMember("input_audio_params", audio_params, allocator);
+
+  // Recognition config - use session configuration
+  rapidjson::Value recognition_config(rapidjson::kObjectType);
+  recognition_config.AddMember("max_alternatives", sessionConfig_.max_alternatives_, allocator);
+  recognition_config.AddMember(
+      "enable_automatic_punctuation", sessionConfig_.automatic_punctuation_, allocator);
+  recognition_config.AddMember(
+      "enable_word_time_offsets", sessionConfig_.word_time_offsets_, allocator);
+  recognition_config.AddMember(
+      "enable_profanity_filter", sessionConfig_.profanity_filter_, allocator);
+  recognition_config.AddMember(
+      "enable_verbatim_transcripts", sessionConfig_.verbatim_transcripts_, allocator);
+  recognition_config.AddMember(
+      "custom_configuration",
+      rapidjson::Value(sessionConfig_.custom_configuration_.c_str(), allocator), allocator);
+  session_config.AddMember("recognition_config", recognition_config, allocator);
+
+  // Speaker diarization config
+  rapidjson::Value diarization_config(rapidjson::kObjectType);
+  diarization_config.AddMember(
+      "enable_speaker_diarization", sessionConfig_.speaker_diarization_, allocator);
+  diarization_config.AddMember(
+      "max_speaker_count", sessionConfig_.diarization_max_speakers_, allocator);
+  session_config.AddMember("speaker_diarization", diarization_config, allocator);
+
+  // Word boosting config
+  rapidjson::Value word_boosting_config(rapidjson::kObjectType);
+  bool enable_word_boosting = !sessionConfig_.boosted_words_file_.empty();
+  word_boosting_config.AddMember("enable_word_boosting", enable_word_boosting, allocator);
+
+  if (enable_word_boosting) {
+    rapidjson::Value word_list(rapidjson::kArrayType);
+    std::ifstream file(sessionConfig_.boosted_words_file_);
+    std::string word;
+    while (std::getline(file, word)) {
+      if (!word.empty()) {
+        word_list.PushBack(rapidjson::Value(word.c_str(), allocator), allocator);
+      }
+    }
+    word_boosting_config.AddMember("word_boosting_list", word_list, allocator);
+  } else {
+    rapidjson::Value empty_list(rapidjson::kArrayType);
+    word_boosting_config.AddMember("word_boosting_list", empty_list, allocator);
+  }
+  session_config.AddMember("word_boosting", word_boosting_config, allocator);
+
+  // Endpointing config
+  rapidjson::Value endpointing_config(rapidjson::kObjectType);
+  endpointing_config.AddMember("start_history", sessionConfig_.start_history_, allocator);
+  endpointing_config.AddMember("start_threshold", sessionConfig_.start_threshold_, allocator);
+  endpointing_config.AddMember("stop_history", sessionConfig_.stop_history_, allocator);
+  endpointing_config.AddMember("stop_threshold", sessionConfig_.stop_threshold_, allocator);
+  endpointing_config.AddMember("stop_history_eou", sessionConfig_.stop_history_eou_, allocator);
+  endpointing_config.AddMember("stop_threshold_eou", sessionConfig_.stop_threshold_eou_, allocator);
+  session_config.AddMember("endpointing_config", endpointing_config, allocator);
+
+  // Create update request
+  rapidjson::Value update_request(rapidjson::kObjectType);
+  update_request.AddMember(
+      "type", rapidjson::Value("transcription_session.update", allocator), allocator);
+  update_request.AddMember("session", session_config, allocator);
+
+  // Send the update request
+  rapidjson::StringBuffer buffer;
+  rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+  update_request.Accept(writer);
+
+  if (IsConnectionOpen()) {
+    std::lock_guard<std::mutex> lock(connectionMutex_);
+    websocketpp::lib::error_code ec;
+    wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec);
+    if (ec) {
+      std::cout << "Session update failed: " << ec.message() << std::endl;
+      return false;
+    } else {
+      std::cout << "Session update request sent" << std::endl;
+    }
+  }
+
+
+  WaitForSessionUpdate();
+  return true;
+}
+
+// Send audio chunks
+void
+nvidia::riva::realtime::RealtimeClient::SendAudioChunks(const bool simulateRealtime)
+{
+  if (audioChunksPtr_ == nullptr) {
+    std::cerr << "Audio chunks pointer is null. Please call InitializeSession first." << std::endl;
+    return;
+  }
+
+  if (!IsSessionInitialized()) {
+    std::cerr << "Session is not initialized. Please call InitializeSession first." << std::endl;
+    return;
+  }
+
+  if (audioChunksPtr_->size() == 0) {
+    std::cerr << "No audio chunks to send. Please add audio chunks to the audio chunks pointer."
+              << std::endl;
+    return;
+  }
+
+  std::cout << "Sending audio chunks with " << (simulateRealtime ? "real-time" : "burst")
+            << " timing..." << std::endl;
+
+  // Track timing for accurate real-time simulation
+  auto stream_start_time = std::chrono::steady_clock::now();
+  size_t chunk_index = 0;
+
+  for (const std::string& chunk_base64 : *audioChunksPtr_) {
+    SendAudioAppend(chunk_base64);
+    SendAudioCommit();
+
+    if (simulateRealtime) {
+      // Calculate the exact time when this chunk should be sent
+      auto chunk_duration_ms = audioChunksPtr_->GetChunkSizeMs();
+      auto expected_send_time =
+          stream_start_time + std::chrono::milliseconds((chunk_index + 1) * chunk_duration_ms);
+
+      auto current_time = std::chrono::steady_clock::now();
+      auto time_to_wait = expected_send_time - current_time;
+
+      // (Optional) timing diagnostics -- elapsed vs. expected send time and
+      // per-chunk drift -- could be logged here when debugging scheduling
+
+      if (time_to_wait > std::chrono::milliseconds(0)) {
+        std::this_thread::sleep_for(time_to_wait);
+      }
+    } else {
+      // Burst mode - progress logging intentionally disabled
+    }
+
+    chunk_index++;
+  }
+  SendAudioDone();
+}
+
+void
+nvidia::riva::realtime::RealtimeClient::HandleMessage(const std::string& message)
+{
+  bool is_last_result = false;
+  rapidjson::Document doc;
+
+  if (doc.Parse(message.c_str()).HasParseError()) {
+    std::cerr << "Failed to parse JSON message" << std::endl;
+    return;
+  }
+
+  std::string eventType = doc.HasMember("type") ? doc["type"].GetString() : "";
+
+  if (eventType == "conversation.created") {
+    std::cout << "Conversation created" << std::endl;
+  } else if (eventType == "transcription_session.updated") {
+    std::cout << "Session updated successfully" << std::endl;
+    sessionInitialized_ = true;
+    // Signal session update completion
+    {
+      std::lock_guard<std::mutex> lock(sessionMutex_);
+      sessionUpdated_ = true;
+    }
+    sessionCv_.notify_one();
+  } else if (eventType == "conversation.item.input_audio_transcription.delta") {
+    if (doc.HasMember("delta")) {
+      std::string delta = doc["delta"].GetString();
+
+      // std::cout << "Delta: " << delta << std::endl;
+      std::cout.flush();  // Ensure immediate output for streaming
+    }
doc["is_last_result"].GetBool() : false; + + if (is_last_result) { + std::cout << "--------------------------------" << std::endl; + std::cout << "Final transcript: " << transcript << std::endl; + std::cout << "Final transcription count: " << finalTranscriptionCount_ << std::endl; + std::cout << "--------------------------------" << std::endl; + + // Transcription completed + std::lock_guard lock(transcriptionMutex_); + transcriptionCompleted_ = true; + transcriptionCv_.notify_one(); + } else { + std::cout << "Interim transcript: " << transcript << std::endl; + } + } else if (eventType.find("error") != std::string::npos) { + std::string errorMsg = "Unknown error"; + if (doc.HasMember("error") && doc["error"].HasMember("message")) { + errorMsg = doc["error"]["message"].GetString(); + } + std::cerr << "Error: " << errorMsg << std::endl; + } else { + // std::cout << "Received message type: " << event_type << std::endl; + } +} + +// Public wrapper methods for microphone audio streaming +void +nvidia::riva::realtime::RealtimeClient::SendAudioAppendPublic(const std::string& audioBase64) +{ + SendAudioAppend(audioBase64); +} + +void +nvidia::riva::realtime::RealtimeClient::SendAudioCommitPublic() +{ + SendAudioCommit(); +} + +void +nvidia::riva::realtime::RealtimeClient::SendAudioDonePublic() +{ + SendAudioDone(); +} \ No newline at end of file diff --git a/riva/clients/realtime/realtime_client.h b/riva/clients/realtime/realtime_client.h new file mode 100644 index 0000000..eaf6669 --- /dev/null +++ b/riva/clients/realtime/realtime_client.h @@ -0,0 +1,165 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: MIT + */ + +#ifndef REALTIME_CLIENT_H +#define REALTIME_CLIENT_H + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "audio_chunks.h" +#include "base_client.h" +#include "riva/utils/stats_builder/stats_builder.h" + +// Add these includes for HTTP functionality +#include +#include +#include +#include +#include + +#include + +namespace nvidia::riva::realtime { +class SessionConfig { + public: + std::size_t connectionTimeoutInMs_; + std::size_t sessionInitTimeoutInMs_; + std::size_t sessionUpdateTimeoutInMs_; + std::size_t transcriptionTimeoutInMs_; + std::size_t chunkDelayTimeInMs_; + + // Add session configuration parameters + std::string language_code_; + std::string model_name_; + int max_alternatives_; + bool automatic_punctuation_; + bool word_time_offsets_; + bool profanity_filter_; + bool verbatim_transcripts_; + std::string boosted_words_file_; + double boosted_words_score_; + bool speaker_diarization_; + int diarization_max_speakers_; + int start_history_; + double start_threshold_; + int stop_history_; + double stop_threshold_; + int stop_history_eou_; + double stop_threshold_eou_; + std::string custom_configuration_; + + // Add HTTP session data + std::string session_id_; + std::string server_url_; +}; + +class RealtimeClient : public WebSocketClientBase { + private: + // Session tracking + bool sessionInitialized_; + bool sessionUpdated_; + std::condition_variable sessionCv_; + std::mutex sessionMutex_; + nvidia::riva::utils::PerformanceStats& perfCounter_; + + + // Event tracking + bool transcriptionCompleted_; + std::condition_variable transcriptionCv_; + std::mutex transcriptionMutex_; + + std::size_t finalTranscriptionCount_; + + // Configurable timing parameters (in milliseconds) + std::size_t 
diff --git a/riva/clients/realtime/realtime_client.h b/riva/clients/realtime/realtime_client.h
new file mode 100644
index 0000000..eaf6669
--- /dev/null
+++ b/riva/clients/realtime/realtime_client.h
@@ -0,0 +1,165 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ */
+
+#ifndef REALTIME_CLIENT_H
+#define REALTIME_CLIENT_H
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "audio_chunks.h"
+#include "base_client.h"
+#include "riva/utils/stats_builder/stats_builder.h"
+
+// Add these includes for HTTP functionality
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+namespace nvidia::riva::realtime {
+class SessionConfig {
+ public:
+  std::size_t connectionTimeoutInMs_;
+  std::size_t sessionInitTimeoutInMs_;
+  std::size_t sessionUpdateTimeoutInMs_;
+  std::size_t transcriptionTimeoutInMs_;
+  std::size_t chunkDelayTimeInMs_;
+
+  // Session configuration parameters
+  std::string language_code_;
+  std::string model_name_;
+  int max_alternatives_;
+  bool automatic_punctuation_;
+  bool word_time_offsets_;
+  bool profanity_filter_;
+  bool verbatim_transcripts_;
+  std::string boosted_words_file_;
+  double boosted_words_score_;
+  bool speaker_diarization_;
+  int diarization_max_speakers_;
+  int start_history_;
+  double start_threshold_;
+  int stop_history_;
+  double stop_threshold_;
+  int stop_history_eou_;
+  double stop_threshold_eou_;
+  std::string custom_configuration_;
+
+  // HTTP session data
+  std::string session_id_;
+  std::string server_url_;
+};
+
+class RealtimeClient : public WebSocketClientBase {
+ private:
+  // Session tracking
+  bool sessionInitialized_;
+  bool sessionUpdated_;
+  std::condition_variable sessionCv_;
+  std::mutex sessionMutex_;
+  nvidia::riva::utils::PerformanceStats& perfCounter_;
+
+  // Event tracking
+  bool transcriptionCompleted_;
+  std::condition_variable transcriptionCv_;
+  std::mutex transcriptionMutex_;
+
+  std::size_t finalTranscriptionCount_;
+
+  // Configurable timing parameters (in milliseconds)
+  std::size_t connectionTimeoutInMs_;
+  std::size_t sessionInitTimeoutInMs_;
+  std::size_t sessionUpdateTimeoutInMs_;
+  std::size_t transcriptionTimeoutInMs_;
+  std::size_t chunkDelayTimeInMs_;
+
+  std::string objectName_;
+
+  // Audio processing
+  std::shared_ptr audioChunksPtr_;
+
+  // Session configuration
+  SessionConfig sessionConfig_;
+
+  // HTTP session data
+  std::string session_id_;
+  std::string server_url_;
+
+  // HTTP session initialization method
+  bool InitializeHttpSession();
+
+  // Helper method for HTTP requests
+  std::string MakeHttpRequest(
+      const std::string& host, int port, const std::string& path, const std::string& method,
+      const std::string& body);
+
+  // Audio streaming methods
+  void SendAudioAppend(const std::string& audioBase64);
+  void SendAudioCommit();
+  void SendAudioDone();
+
+  // Override base class methods
+  void HandleMessage(const std::string& message) override;
+
+ public:
+  RealtimeClient(
+      const std::string& objectName, const std::shared_ptr audioChunksPtr,
+      nvidia::riva::utils::PerformanceStats& perfCounter);
+  ~RealtimeClient() = default;
+
+  void Log(const std::string& message);
+
+  // Timing configuration
+  void SetTimingConfig(
+      const std::size_t connectionTimeoutInMs, const std::size_t sessionInitTimeoutInMs,
+      const std::size_t sessionUpdateTimeoutInMs, const std::size_t transcriptionTimeoutInMs,
+      const std::size_t chunkDelayTimeInMs);
+
+  // Session configuration
+  void SetSessionConfig(const SessionConfig& config) { sessionConfig_ = config; }
+
+  // Session management methods
+  bool InitializeSession();
+  bool UpdateSessionConfig();
+
+  bool IsSessionInitialized() const { return sessionInitialized_; }
+
+  // Wait methods
+  bool WaitForSessionUpdate();
+  bool WaitForTranscriptionCompletion();
+
+  // WAV file processing methods
+  void SendAudioChunks(const bool simulateRealtime = false);
+
+  // Public audio streaming methods for microphone input
+  void SendAudioAppendPublic(const std::string& audioBase64);
+  void SendAudioCommitPublic();
+  void SendAudioDonePublic();
+
+  // Server URL and session id accessors
+  void SetServerUrl(const std::string& server_url) { server_url_ = server_url; }
+  std::string GetSessionId() const { return session_id_; }
+};
+
+} // namespace nvidia::riva::realtime
+
+#endif // REALTIME_CLIENT_H
\ No newline at end of file
diff --git a/riva/clients/realtime/riva_realtime_asr_client.cc b/riva/clients/realtime/riva_realtime_asr_client.cc
new file mode 100644
index 0000000..2c75caa
--- /dev/null
+++ b/riva/clients/realtime/riva_realtime_asr_client.cc
@@ -0,0 +1,613 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "audio_chunks.h" +#include "riva/clients/realtime/realtime_client.h" +#include "riva/utils/stats_builder/stats_builder.h" + +// Add these includes for HTTP functionality +#include +#include +#include +#include +#include + +#include + +using namespace nvidia::riva::utils; +using namespace nvidia::riva::realtime; + +// Define command-line flags (matching streaming client) +DEFINE_string( + audio_file, "", "Folder that contains audio files to transcribe or individual audio file name"); +DEFINE_int32( + max_alternatives, 1, + "Maximum number of alternative transcripts to return (up to limit configured on server)"); +DEFINE_bool( + profanity_filter, false, + "Flag that controls if generated transcripts should be filtered for the profane words"); +DEFINE_bool(automatic_punctuation, true, "Flag that controls if transcript should be punctuated"); +DEFINE_bool(word_time_offsets, true, "Flag that controls if word time stamps are requested"); +DEFINE_bool( + simulate_realtime, false, "Flag that controls if audio files should be sent in realtime"); +DEFINE_string(audio_device, "", "Name of audio device to use"); +DEFINE_string( + riva_uri, "ws://127.0.0.1:9090/v1/realtime?intent=transcription", "URI to access riva-server"); +DEFINE_int32(num_iterations, 1, "Number of times to loop over audio files"); +DEFINE_int32(num_parallel_requests, 1, "Number of parallel requests to keep in flight"); +DEFINE_int32(chunk_duration_ms, 100, "Chunk duration in milliseconds"); +DEFINE_bool(print_transcripts, true, "Print final transcripts"); +DEFINE_bool(interim_results, true, "Print intermediate transcripts"); +DEFINE_string( + output_filename, "final_transcripts.json", + "Filename of .json file containing output transcripts"); +DEFINE_string(model_name, "", "Name of the TRTIS model to use"); +DEFINE_string(language_code, "en-US", "Language code of the model to use"); +DEFINE_string(boosted_words_file, "", "File with a list of words to boost. One line per word."); +DEFINE_double(boosted_words_score, 10., "Score by which to boost the boosted words"); +DEFINE_bool( + verbatim_transcripts, true, + "True returns text exactly as it was said with no normalization. False applies text inverse " + "normalization"); +DEFINE_string(ssl_root_cert, "", "Path to SSL root certificates file"); +DEFINE_string(ssl_client_key, "", "Path to SSL client certificates key"); +DEFINE_string(ssl_client_cert, "", "Path to SSL client certificates file"); +DEFINE_bool( + use_ssl, false, + "Whether to use SSL credentials or not. 
If ssl_root_cert is specified, this is assumed to be "
+    "true");
+DEFINE_string(metadata, "", "Comma separated key-value pair(s) of metadata to be sent to server");
+DEFINE_int32(
+    start_history, -1, "Value (in milliseconds) to detect and initiate start of speech utterance");
+DEFINE_double(
+    start_threshold, -1.,
+    "Threshold value to determine at what percentage start of speech is initiated");
+DEFINE_int32(stop_history, -1, "Value (in milliseconds) to detect endpoint and reset decoder");
+DEFINE_double(stop_threshold, -1., "Threshold value to determine when endpoint detected");
+DEFINE_int32(
+    stop_history_eou, -1,
+    "Value (in milliseconds) to detect endpoint and generate an intermediate final transcript");
+DEFINE_double(
+    stop_threshold_eou, -1.,
+    "Threshold value for likelihood of blanks before detecting end of utterance");
+DEFINE_string(
+    custom_configuration, "",
+    "Custom configurations to be sent to the server as key value pairs");
+DEFINE_bool(speaker_diarization, false, "Flag that controls if speaker diarization is requested");
+DEFINE_int32(
+    diarization_max_speakers, 4,
+    "Max number of speakers to detect when performing speaker diarization. Default is 4 (Max)");
+DEFINE_uint64(timeout_ms, 10000, "Timeout for GRPC channel creation");
+DEFINE_uint64(max_grpc_message_size, 16777216, "Max GRPC message size");
+
+// Additional realtime-specific flags
+DEFINE_int32(connection_timeout_ms, 100000, "Connection timeout in milliseconds");
+DEFINE_int32(session_init_timeout_ms, 100000, "Session initialization timeout in milliseconds");
+DEFINE_int32(session_update_timeout_ms, 100000, "Session update timeout in milliseconds");
+DEFINE_int32(transcription_timeout_ms, 100000, "Transcription timeout in milliseconds");
+DEFINE_int32(chunk_delay_time_ms, 160, "Delay between audio chunks in milliseconds");
+DEFINE_bool(verbose_logging, false, "Enable verbose logging");
+DEFINE_bool(show_detailed_stats, true, "Show detailed statistics");
+DEFINE_bool(show_tabular_stats, true, "Show tabular statistics");
+
+// Microphone configuration (hardcoded like ASR clients)
+const int MIC_SAMPLE_RATE = 16000;  // 16kHz
+const int MIC_CHANNELS = 1;         // Mono
+const int MIC_BIT_DEPTH = 16;       // 16-bit
+
+// Global client pointers for signal handling
+std::vector<RealtimeClient*> g_clients;
+std::mutex g_clients_mutex;
+
+// Global exit flag for microphone coordination (like ASR clients)
+std::atomic<bool> g_request_exit(false);
+
+// Signal handler for graceful shutdown
+void
+signal_handler(int signal)
+{
+  std::cout << "\nReceived signal " << signal << ", shutting down gracefully..." 
<< std::endl; + g_request_exit = true; + + for (auto client : g_clients) { + client->Close(); + } + exit(0); +} + +// Helper function to format throughput as 10.246e00 instead of 1.0246e+01 +std::string +format_throughput(double value) +{ + std::ostringstream oss; + oss << std::fixed << std::setprecision(3) << value << "e00"; + return oss.str(); +} + +// Helper function to create appropriate audio chunks based on input type +std::shared_ptr +CreateAudioChunks( + const std::string& audio_file, const std::string& audio_device, int chunk_duration_ms) +{ + if (!audio_device.empty()) { + // Create microphone-based audio chunks + std::cout << "Creating microphone audio chunks for device: " << audio_device << std::endl; + std::cout << "Sample rate: " << MIC_SAMPLE_RATE << " Hz, Channels: " << MIC_CHANNELS + << ", Bit depth: " << MIC_BIT_DEPTH << std::endl; + + auto mic_chunks = std::make_shared( + audio_device, chunk_duration_ms, MIC_SAMPLE_RATE, MIC_CHANNELS, MIC_BIT_DEPTH); + + if (!mic_chunks->Init()) { + std::cerr << "Failed to initialize microphone audio chunks" << std::endl; + return nullptr; + } + + return mic_chunks; + } else if (!audio_file.empty()) { + // Create file-based audio chunks + std::cout << "Creating file audio chunks for: " << audio_file << std::endl; + + auto file_chunks = std::make_shared(audio_file, chunk_duration_ms); + + if (!file_chunks->Init()) { + std::cerr << "Failed to initialize file audio chunks" << std::endl; + return nullptr; + } + + return file_chunks; + } + + std::cerr << "No audio source specified" << std::endl; + return nullptr; +} + +// Function to run the client example +void +client_runner( + const std::string& uri, + const std::shared_ptr& audio_chunks, + PerformanceStats& perfCounter, const std::size_t connectionTimeoutInMs, + const std::size_t sessionInitTimeoutInMs, const std::size_t sessionUpdateTimeoutInMs, + const std::size_t transcriptionTimeoutInMs, const std::size_t chunkDelayTimeInMs, + const bool simulateRealtime = false) +{ + nvidia::riva::realtime::RealtimeClient client( + perfCounter.GetObjectName(), audio_chunks, perfCounter); + + // Extract server URL from URI (remove ws:// and path) + std::string server_url = uri; + if (server_url.find("ws://") == 0) { + server_url = server_url.substr(5); // Remove "ws://" + } else if (server_url.find("wss://") == 0) { + server_url = server_url.substr(6); // Remove "wss://" + } + + // Remove path part (everything after first /) + size_t path_pos = server_url.find('/'); + if (path_pos != std::string::npos) { + server_url = server_url.substr(0, path_pos); + } + + client.SetServerUrl(server_url); + + // Set session configuration from command line flags (these will override defaults) + nvidia::riva::realtime::SessionConfig sessionConfig; + + // Only set values if they were provided by user (not default values) + if (!FLAGS_language_code.empty() && FLAGS_language_code != "en-US") { + sessionConfig.language_code_ = FLAGS_language_code; + } + if (!FLAGS_model_name.empty()) { + sessionConfig.model_name_ = FLAGS_model_name; + } + if (FLAGS_max_alternatives != 1) { + sessionConfig.max_alternatives_ = FLAGS_max_alternatives; + } + if (!FLAGS_automatic_punctuation) { // Default is true, so only override if false + sessionConfig.automatic_punctuation_ = FLAGS_automatic_punctuation; + } + if (!FLAGS_word_time_offsets) { // Default is true, so only override if false + sessionConfig.word_time_offsets_ = FLAGS_word_time_offsets; + } + if (FLAGS_profanity_filter) { // Default is false, so only override if true + 
sessionConfig.profanity_filter_ = FLAGS_profanity_filter; + } + if (!FLAGS_verbatim_transcripts) { // Default is true, so only override if false + sessionConfig.verbatim_transcripts_ = FLAGS_verbatim_transcripts; + } + if (!FLAGS_boosted_words_file.empty()) { + sessionConfig.boosted_words_file_ = FLAGS_boosted_words_file; + sessionConfig.boosted_words_score_ = FLAGS_boosted_words_score; + } + if (FLAGS_speaker_diarization) { // Default is false, so only override if true + sessionConfig.speaker_diarization_ = FLAGS_speaker_diarization; + sessionConfig.diarization_max_speakers_ = FLAGS_diarization_max_speakers; + } + if (FLAGS_start_history > 0) { + sessionConfig.start_history_ = FLAGS_start_history; + } + if (FLAGS_start_threshold > 0) { + sessionConfig.start_threshold_ = FLAGS_start_threshold; + } + if (FLAGS_stop_history > 0) { + sessionConfig.stop_history_ = FLAGS_stop_history; + } + if (FLAGS_stop_threshold > 0) { + sessionConfig.stop_threshold_ = FLAGS_stop_threshold; + } + if (FLAGS_stop_history_eou > 0) { + sessionConfig.stop_history_eou_ = FLAGS_stop_history_eou; + } + if (FLAGS_stop_threshold_eou > 0) { + sessionConfig.stop_threshold_eou_ = FLAGS_stop_threshold_eou; + } + if (!FLAGS_custom_configuration.empty()) { + sessionConfig.custom_configuration_ = FLAGS_custom_configuration; + } + + client.SetSessionConfig(sessionConfig); + client.SetVerboseLogging(FLAGS_verbose_logging); + client.SetTimingConfig( + connectionTimeoutInMs, sessionInitTimeoutInMs, sessionUpdateTimeoutInMs, + transcriptionTimeoutInMs, chunkDelayTimeInMs); + + // Step 1: Connect to the WebSocket server + client.Connect(uri); + + std::thread client_thread([&client]() { client.Run(); }); + + // Step 2: Wait for the connection to be established + if (!client.WaitForConnection()) { + std::cerr << "Failed to establish WebSocket connection" << std::endl; + client.Close(); + client_thread.join(); + return; + } + + std::cout << "WebSocket connection established" << std::endl; + + // Step 3: Initialize the session + if (!client.InitializeSession()) { + std::cerr << "Failed to initialize session" << std::endl; + client.Close(); + client_thread.join(); + return; + } + + std::cout << "Waiting for session update confirmation..." << std::endl; + + // Step 4: Wait for the session to be updated + if (!client.WaitForSessionUpdate()) { + std::cerr << "Session update timeout" << std::endl; + client.Close(); + client_thread.join(); + return; + } + + // Step 5: Send the audio chunks with realistic timing + perfCounter.StartProcessingTimer(); + perfCounter.SetAudioDurationInSeconds(audio_chunks->GetDurationSeconds()); + + // For microphone input, we need to start capture before sending chunks + if (auto mic_chunks = std::dynamic_pointer_cast(audio_chunks)) { + std::cout << "Starting microphone capture..." << std::endl; + + // Ensure microphone is stopped on early exit + auto mic_cleanup = [mic_chunks]() { + if (mic_chunks->IsCapturing()) { + mic_chunks->StopCapture(); + std::cout << "Stopped microphone capture (cleanup)" << std::endl; + } + }; + + if (!mic_chunks->StartCapture()) { + std::cerr << "Failed to start microphone capture" << std::endl; + mic_cleanup(); + client.Close(); + client_thread.join(); + return; + } + + // For microphone: start continuous audio streaming in background thread + std::cout << "Starting continuous audio streaming..." 
<< std::endl;
+
+    // Start continuous audio streaming in a separate thread (like ASR clients)
+    std::thread audio_thread([&client, mic_chunks]() {
+      // Continuous streaming loop for microphone input
+      while (!g_request_exit && mic_chunks->IsCapturing()) {
+        // Get the latest audio chunk from microphone
+        std::string latest_chunk = mic_chunks->GetLatestChunk();
+        if (!latest_chunk.empty()) {
+          // Send the audio chunk to the server
+          client.SendAudioAppendPublic(latest_chunk);
+          client.SendAudioCommitPublic();
+        }
+
+        // Small delay to prevent busy waiting
+        std::this_thread::sleep_for(std::chrono::milliseconds(10));
+      }
+
+      // Send audio done when streaming ends
+      if (!g_request_exit) {
+        client.SendAudioDonePublic();
+      }
+    });
+
+    // Keep microphone running while waiting for transcription completion
+    // The microphone will continue capturing until transcription completes or exit is requested
+    std::cout << "Microphone is now active. Press Ctrl+C to stop." << std::endl;
+
+    // For microphone input, wait for user interruption or transcription completion
+    // The audio thread will continue running until g_request_exit is set
+    while (!g_request_exit && mic_chunks->IsCapturing()) {
+      std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    }
+
+    // Wait for audio transmission to complete
+    audio_thread.join();
+
+  } else {
+    // For file-based audio, send chunks normally
+    client.SendAudioChunks(simulateRealtime);
+  }
+
+  std::cout << "Waiting for transcription completion..." << std::endl;
+
+  // Step 6: Wait for the transcription to be completed
+  if (client.WaitForTranscriptionCompletion()) {
+    std::cout << "Transcription completed successfully!" << std::endl;
+    perfCounter.EndProcessingTimer();
+    perfCounter.SetSuccess(true);
+  } else {
+    std::cout << "Transcription did not complete within timeout" << std::endl;
+    perfCounter.EndProcessingTimer();
+  }
+
+  // Step 6.5: Stop microphone capture if it was used (after transcription completes)
+  if (auto mic_chunks = std::dynamic_pointer_cast(audio_chunks)) {
+    // Set exit flag to stop the audio streaming thread
+    g_request_exit = true;
+
+    // Stop microphone capture
+    mic_chunks->StopCapture();
+    std::cout << "Stopped microphone capture" << std::endl;
+  }
+
+  // Step 7: Close the WebSocket connection
+  client.Close();
+  client_thread.join();
+
+  // Step 8: Report the stats
+  perfCounter.ReportStats();
+}
+
+int
+main(int argc, char* argv[])
+{
+  google::InitGoogleLogging(argv[0]);
+  FLAGS_logtostderr = 1;
+
+  // Set up usage message
+  std::stringstream str_usage;
+  str_usage << "Usage: riva_realtime_asr_client " << std::endl;
+  str_usage << " --audio_file= " << std::endl;
+  str_usage << " --audio_device= " << std::endl;
+  str_usage << " --automatic_punctuation=" << std::endl;
+  str_usage << " --max_alternatives=" << std::endl;
+  str_usage << " --profanity_filter=" << std::endl;
+  str_usage << " --word_time_offsets=" << std::endl;
+  str_usage << " --riva_uri= " << std::endl;
+  str_usage << " --chunk_duration_ms= " << std::endl;
+  str_usage << " --interim_results= " << std::endl;
+  str_usage << " --simulate_realtime= " << std::endl;
+  str_usage << " --num_iterations= " << std::endl;
+  str_usage << " --num_parallel_requests= " << std::endl;
+  str_usage << " --print_transcripts= " << std::endl;
+  str_usage << " --output_filename=" << std::endl;
+  str_usage << " --verbatim_transcripts=" << std::endl;
+  str_usage << " --language_code=" << 
std::endl; + str_usage << " --boosted_words_file=" << std::endl; + str_usage << " --boosted_words_score=" << std::endl; + str_usage << " --ssl_root_cert=" << std::endl; + str_usage << " --ssl_client_key=" << std::endl; + str_usage << " --ssl_client_cert=" << std::endl; + str_usage << " --model_name=" << std::endl; + str_usage << " --metadata=" << std::endl; + str_usage << " --start_history=" << std::endl; + str_usage << " --start_threshold=" << std::endl; + str_usage << " --stop_history=" << std::endl; + str_usage << " --stop_history_eou=" << std::endl; + str_usage << " --stop_threshold=" << std::endl; + str_usage << " --stop_threshold_eou=" << std::endl; + str_usage << " --custom_configuration=" << std::endl; + str_usage << " --speaker_diarization=" << std::endl; + str_usage << " --diarization_max_speakers=" << std::endl; + str_usage << " --timeout_ms=" << std::endl; + str_usage << " --max_grpc_message_size=" << std::endl; + str_usage << " --connection_timeout_ms=" << std::endl; + str_usage << " --session_init_timeout_ms=" << std::endl; + str_usage << " --session_update_timeout_ms=" << std::endl; + str_usage << " --transcription_timeout_ms=" << std::endl; + str_usage << " --chunk_delay_time_ms=" << std::endl; + str_usage << " --verbose_logging=" << std::endl; + str_usage << " --show_detailed_stats=" << std::endl; + str_usage << " --show_tabular_stats=" << std::endl; + // Note: Microphone configuration is hardcoded (16kHz, mono, 16-bit) like ASR clients + + gflags::SetUsageMessage(str_usage.str()); + + if (argc < 2) { + std::cout << gflags::ProgramUsage(); + return 1; + } + + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (argc > 1) { + std::cout << gflags::ProgramUsage(); + return 1; + } + + // Validate arguments + if (FLAGS_max_alternatives < 1) { + std::cerr << "max_alternatives must be greater than or equal to 1." 
<< std::endl;
+    return 1;
+  }
+
+  if (FLAGS_num_iterations < 1) {
+    std::cerr << "num_iterations must be greater than 0" << std::endl;
+    return 1;
+  }
+
+  if (FLAGS_num_parallel_requests < 1) {
+    std::cerr << "num_parallel_requests must be greater than 0" << std::endl;
+    return 1;
+  }
+
+  // Check if audio file or device is specified
+  if (FLAGS_audio_file.empty() && FLAGS_audio_device.empty()) {
+    std::cerr << "Either --audio_file or --audio_device must be specified" << std::endl;
+    return 1;
+  }
+
+  // Validate audio file exists if specified
+  if (!FLAGS_audio_file.empty() && !std::filesystem::exists(FLAGS_audio_file)) {
+    std::cerr << "Audio file does not exist: " << FLAGS_audio_file << std::endl;
+    return 1;
+  }
+
+  // Microphone capture parameters are hardcoded like the other ASR clients
+  // (MIC_SAMPLE_RATE = 16000, MIC_CHANNELS = 1, MIC_BIT_DEPTH = 16), so no
+  // extra validation is needed here.
+  if (!FLAGS_audio_device.empty()) {
+    // For microphone input, enforce single request and iteration
+    if (FLAGS_num_parallel_requests != 1) {
+      std::cout << "Warning: num_parallel_requests set to 1 for microphone input" << std::endl;
+      FLAGS_num_parallel_requests = 1;
+    }
+    if (FLAGS_num_iterations != 1) {
+      std::cout << "Warning: num_iterations set to 1 for microphone input" << std::endl;
+      FLAGS_num_iterations = 1;
+    }
+    if (FLAGS_simulate_realtime) {
+      std::cout << "Warning: simulate_realtime set to false for microphone input" << std::endl;
+      FLAGS_simulate_realtime = false;
+    }
+  }
+
+  // Use command-line arguments
+  const std::string uri = FLAGS_riva_uri;
+  const std::string audio_file_path = FLAGS_audio_file;
+  const std::string audio_device = FLAGS_audio_device;
+  const std::size_t num_iterations = FLAGS_num_iterations;
+  const std::size_t num_parallel_clients = FLAGS_num_parallel_requests;
+  const bool simulateRealtime = FLAGS_simulate_realtime;
+
+  const std::size_t connectionTimeoutInMs = FLAGS_connection_timeout_ms;
+  const std::size_t sessionInitTimeoutInMs = FLAGS_session_init_timeout_ms;
+  const std::size_t sessionUpdateTimeoutInMs = FLAGS_session_update_timeout_ms;
+  const std::size_t transcriptionTimeoutInMs = FLAGS_transcription_timeout_ms;
+  const std::size_t chunkDelayTimeInMs = FLAGS_chunk_delay_time_ms;
+  const std::size_t chunk_duration_ms = FLAGS_chunk_duration_ms;
+
+  // Create appropriate audio chunks based on input type
+  const auto audio_chunks = CreateAudioChunks(audio_file_path, audio_device, chunk_duration_ms);
+  if (!audio_chunks) {
+    std::cerr << "Failed to create audio chunks" << std::endl;
+    return 1;
+  }
+
+  PerformanceStats overallPerf("Overall");
+
+  // For microphone input, duration is ongoing, so we'll use a reasonable estimate
+  double audio_duration = audio_chunks->GetDurationSeconds();
+  if (audio_duration <= 0.0 && !audio_device.empty()) {
+    // Microphone input - use a reasonable duration estimate for stats
+    audio_duration = 60.0;  // Assume 1 minute for microphone sessions
+    std::cout << "Using estimated duration of " << audio_duration << " seconds for microphone input"
+              << std::endl;
+  }
+
+  overallPerf.SetAudioDurationInSeconds(audio_duration * num_iterations * num_parallel_clients);
+
+  // Set up signal handlers for graceful shutdown (before starting async operations)
+  signal(SIGINT, signal_handler);
+  signal(SIGTERM, signal_handler);
+
+  // Create StatsBuilder for all clients
+  StatsBuilder statsBuilder("client", audio_duration, num_parallel_clients);
+
+  // Run iterations asynchronously
+  std::vector<std::future<void>> futures;
+  std::cout << "Starting " << num_parallel_clients << " async clients..." << std::endl;
"Starting " << num_parallel_clients << " async clients..." << std::endl; + + overallPerf.StartProcessingTimer(); + for (std::size_t N = 0; N < num_parallel_clients; ++N) { + // Launch each client asynchronously + futures.emplace_back(std::async(std::launch::async, [&, N]() { + std::cout << "Starting client " << (N + 1) << "/" << num_parallel_clients << std::endl; + + for (std::size_t M = 0; M < num_iterations; ++M) { + std::cout << " Running iteration " << (M + 1) << "/" << num_iterations << std::endl; + client_runner( + uri, audio_chunks, statsBuilder.GetPerformanceStats(N), connectionTimeoutInMs, + sessionInitTimeoutInMs, sessionUpdateTimeoutInMs, transcriptionTimeoutInMs, + chunkDelayTimeInMs, simulateRealtime); + } + + std::cout << "Completed client " << (N + 1) << "/" << num_parallel_clients << std::endl; + })); + } + + // Wait for all iterations to complete + std::cout << "Waiting for all iterations to complete..." << std::endl; + for (auto& future : futures) { + future.wait(); + } + std::cout << "All iterations completed!" << std::endl; + overallPerf.EndProcessingTimer(); + + // Conditional stats reporting based on flags + if (FLAGS_show_detailed_stats) { + statsBuilder.ReportDetailedStats(); + } + if (FLAGS_show_tabular_stats) { + statsBuilder.ReportTabularStats(); + } + + statsBuilder.ReportCumulativeStats(); + overallPerf.ReportStats(); + return 0; +} \ No newline at end of file diff --git a/riva/utils/stats_builder/BUILD b/riva/utils/stats_builder/BUILD new file mode 100644 index 0000000..5b550ee --- /dev/null +++ b/riva/utils/stats_builder/BUILD @@ -0,0 +1,7 @@ +cc_library( + name = "stats_builder_lib", + srcs = ["stats_builder.cpp"], + hdrs = ["stats_builder.h"], + includes = ["."], + visibility = ["//visibility:public"], +) \ No newline at end of file diff --git a/riva/utils/stats_builder/stats_builder.cpp b/riva/utils/stats_builder/stats_builder.cpp new file mode 100644 index 0000000..6484d75 --- /dev/null +++ b/riva/utils/stats_builder/stats_builder.cpp @@ -0,0 +1,358 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
diff --git a/riva/utils/stats_builder/BUILD b/riva/utils/stats_builder/BUILD
new file mode 100644
index 0000000..5b550ee
--- /dev/null
+++ b/riva/utils/stats_builder/BUILD
@@ -0,0 +1,7 @@
+cc_library(
+    name = "stats_builder_lib",
+    srcs = ["stats_builder.cpp"],
+    hdrs = ["stats_builder.h"],
+    includes = ["."],
+    visibility = ["//visibility:public"],
+)
\ No newline at end of file
diff --git a/riva/utils/stats_builder/stats_builder.cpp b/riva/utils/stats_builder/stats_builder.cpp
new file mode 100644
index 0000000..6484d75
--- /dev/null
+++ b/riva/utils/stats_builder/stats_builder.cpp
@@ -0,0 +1,358 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "stats_builder.h"
+
+#include <algorithm>
+#include <cmath>
+#include <iomanip>
+#include <iostream>
+
+namespace nvidia::riva::utils {
+
+PerformanceStats::PerformanceStats(const std::string& objectName)
+    : success_(false), objectName_(objectName),
+      processing_start_time_(std::chrono::steady_clock::now()),
+      processing_end_time_(std::chrono::steady_clock::now()), audio_duration_seconds_(0.0)
+{
+}
+
+StatsBuilder::StatsBuilder(
+    const std::string& objectName, double audio_duration_seconds, std::size_t num_iterations)
+    : audio_duration_seconds_(audio_duration_seconds), num_iterations_(num_iterations),
+      object_name_(objectName)
+{
+  // Pre-allocate the vector with the expected number of iterations
+  performanceStats_.reserve(num_iterations);
+
+  // Create PerformanceStats objects for each iteration
+  for (std::size_t i = 0; i < num_iterations; ++i) {
+    std::string iteration_name = objectName + "-" + std::to_string(i);
+    performanceStats_.emplace_back(iteration_name);
+    // Set the audio duration for each performance stats object
+    performanceStats_.back().SetAudioDurationInSeconds(audio_duration_seconds);
+  }
+}
+
+void
+PerformanceStats::StartProcessingTimer()
+{
+  processing_start_time_ = std::chrono::steady_clock::now();
+}
+
+void
+PerformanceStats::EndProcessingTimer()
+{
+  processing_end_time_ = std::chrono::steady_clock::now();
+}
+
+double
+PerformanceStats::GetRuntimeInMs() const
+{
+  auto durationInMs = std::chrono::duration_cast<std::chrono::milliseconds>(
+      processing_end_time_ - processing_start_time_);
+  return durationInMs.count();
+}
+
+double
+PerformanceStats::GetRuntimeInSeconds() const
+{
+  return GetRuntimeInMs() / 1000.0;
+}
+
+void
+PerformanceStats::SetAudioDurationInSeconds(double audio_duration_seconds)
+{
+  audio_duration_seconds_ = audio_duration_seconds;
+}
+
+double
+PerformanceStats::GetThroughputRTFX() const
+{
+  double runtimeInMs = GetRuntimeInMs();
+  if (runtimeInMs > 0.0 && audio_duration_seconds_ > 0.0) {
+    // RTFX = (Total Audio Processed in seconds) × 1000 ÷ (Total Runtime in milliseconds)
+    return (audio_duration_seconds_ * 1000.0) / runtimeInMs;
+  }
+  return 0.0;
+}
+
+void
+PerformanceStats::SetObjectName(const std::string& objectName)
+{
+  objectName_ = objectName;
+}
+
+std::string
+PerformanceStats::GetObjectName() const
+{
+  return objectName_;
+}
+
+void
+PerformanceStats::ReportStats()
+{
+  std::cout << "Object Name: " << GetObjectName() << std::endl;
+  std::cout << "Success: " << IsSuccess() << std::endl;
+  std::cout << "Audio Duration: " << audio_duration_seconds_ << " seconds" << std::endl;
+  std::cout << "Total Runtime: " << GetRuntimeInMs() << " ms (" << GetRuntimeInSeconds()
+            << " seconds)" << std::endl;
+  std::cout << "Throughput: " << GetThroughputRTFX() << " RTFX" << std::endl;
+}
+
+void
+StatsBuilder::ReportCumulativeStats()
+{
+  std::cout << "Cumulative Stats" << std::endl;
+  std::cout << "=================" << std::endl;
+  for (const auto& performanceStats : performanceStats_) {
+    std::cout << "Object Name: " << performanceStats.GetObjectName() << std::endl;
+    std::cout << "Total Runtime: " << performanceStats.GetRuntimeInMs() << " ms ("
+              << performanceStats.GetRuntimeInSeconds() << " seconds)" << std::endl;
+    std::cout << "Throughput: " << performanceStats.GetThroughputRTFX() << " RTFX" << std::endl;
+  }
+}
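+
+// Worked example for the interpolation below: given sorted values
+// {100, 200, 300, 400} and percentile = 90, index = 0.90 * (4 - 1) = 2.7,
+// lower_index = 2, weight = 0.7, so the result is
+// 300 * 0.3 + 400 * 0.7 = 370. (Numbers are illustrative only.)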
+
+// Helper function to calculate percentile
+double
+CalculatePercentile(const std::vector<double>& values, double percentile)
+{
+  if (values.empty())
+    return 0.0;
+
+  std::vector<double> sorted_values = values;
+  std::sort(sorted_values.begin(), sorted_values.end());
+
+  double index = (percentile / 100.0) * (sorted_values.size() - 1);
+  int lower_index = static_cast<int>(index);
+  int upper_index = lower_index + 1;
+
+  if (upper_index >= static_cast<int>(sorted_values.size())) {
+    return sorted_values[lower_index];
+  }
+
+  double weight = index - lower_index;
+  return sorted_values[lower_index] * (1 - weight) + sorted_values[upper_index] * weight;
+}
+
+// Statistical methods for runtime
+double
+StatsBuilder::GetAverageRuntime() const
+{
+  if (performanceStats_.empty())
+    return 0.0;
+
+  double sum = 0.0;
+  for (const auto& stats : performanceStats_) {
+    sum += stats.GetRuntimeInMs();
+  }
+  return sum / performanceStats_.size();
+}
+
+double
+StatsBuilder::GetP50Runtime() const
+{
+  std::vector<double> runtimes;
+  for (const auto& stats : performanceStats_) {
+    runtimes.push_back(stats.GetRuntimeInMs());
+  }
+  return CalculatePercentile(runtimes, 50.0);
+}
+
+double
+StatsBuilder::GetP90Runtime() const
+{
+  std::vector<double> runtimes;
+  for (const auto& stats : performanceStats_) {
+    runtimes.push_back(stats.GetRuntimeInMs());
+  }
+  return CalculatePercentile(runtimes, 90.0);
+}
+
+double
+StatsBuilder::GetP95Runtime() const
+{
+  std::vector<double> runtimes;
+  for (const auto& stats : performanceStats_) {
+    runtimes.push_back(stats.GetRuntimeInMs());
+  }
+  return CalculatePercentile(runtimes, 95.0);
+}
+
+double
+StatsBuilder::GetP99Runtime() const
+{
+  std::vector<double> runtimes;
+  for (const auto& stats : performanceStats_) {
+    runtimes.push_back(stats.GetRuntimeInMs());
+  }
+  return CalculatePercentile(runtimes, 99.0);
+}
+
+double
+StatsBuilder::GetMinRuntime() const
+{
+  if (performanceStats_.empty())
+    return 0.0;
+
+  double min_runtime = performanceStats_[0].GetRuntimeInMs();
+  for (const auto& stats : performanceStats_) {
+    min_runtime = std::min(min_runtime, stats.GetRuntimeInMs());
+  }
+  return min_runtime;
+}
+
+double
+StatsBuilder::GetMaxRuntime() const
+{
+  if (performanceStats_.empty())
+    return 0.0;
+
+  double max_runtime = performanceStats_[0].GetRuntimeInMs();
+  for (const auto& stats : performanceStats_) {
+    max_runtime = std::max(max_runtime, stats.GetRuntimeInMs());
+  }
+  return max_runtime;
+}
+
+// Statistical methods for throughput
+double
+StatsBuilder::GetAverageThroughput() const
+{
+  if (performanceStats_.empty())
+    return 0.0;
+
+  double sum = 0.0;
+  for (const auto& stats : performanceStats_) {
+    sum += stats.GetThroughputRTFX();
+  }
+  return sum / performanceStats_.size();
+}
+
+double
+StatsBuilder::GetCumulativeThroughput() const
+{
+  if (performanceStats_.empty())
+    return 0.0;
+
+  double sum = 0.0;
+  for (const auto& stats : performanceStats_) {
+    sum += stats.GetThroughputRTFX();
+  }
+  return sum;
+}
+
+double
+StatsBuilder::GetP90Throughput() const
+{
+  std::vector<double> throughputs;
+  for (const auto& stats : performanceStats_) {
+    throughputs.push_back(stats.GetThroughputRTFX());
+  }
+  return CalculatePercentile(throughputs, 90.0);
+}
+
+double
+StatsBuilder::GetP95Throughput() const
+{
+  std::vector<double> throughputs;
+  for (const auto& stats : performanceStats_) {
+    throughputs.push_back(stats.GetThroughputRTFX());
+  }
+  return CalculatePercentile(throughputs, 95.0);
+}
+
+double
+StatsBuilder::GetP99Throughput() const
+{
+  std::vector<double> throughputs;
+  for (const auto& stats : performanceStats_) {
+    throughputs.push_back(stats.GetThroughputRTFX());
+  }
+  return CalculatePercentile(throughputs, 99.0);
+}
+
+bool
+StatsBuilder::AreAllIterationsSuccessful() const
+{
+  if (performanceStats_.empty())
+    return false;
+
+  for (const auto& stats : performanceStats_) {
+    if (!stats.IsSuccess()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+std::size_t
+StatsBuilder::GetSuccessfulIterationsCount() const
+{
+  std::size_t success_count = 0;
+  for (const auto& stats : performanceStats_) {
+    if (stats.IsSuccess()) {
+      success_count++;
+    }
+  }
+  return success_count;
+}
+
+std::size_t
+StatsBuilder::GetFailedIterationsCount() const
+{
+  return performanceStats_.size() - GetSuccessfulIterationsCount();
+}
+
+double
+StatsBuilder::GetSuccessRate() const
+{
+  if (performanceStats_.empty())
+    return 0.0;
+  return static_cast<double>(GetSuccessfulIterationsCount()) / performanceStats_.size() * 100.0;
+}
+
+void
+StatsBuilder::ReportDetailedStats() const
+{
+  std::cout << "\n=== DETAILED PERFORMANCE STATISTICS ===" << std::endl;
+  std::cout << "Audio Duration: " << audio_duration_seconds_ << " seconds" << std::endl;
+  std::cout << "Number of Iterations: " << num_iterations_ << std::endl;
+  std::cout << "Sample Count: " << performanceStats_.size() << std::endl;
+
+  // Add success rate information
+  std::cout << "Success Rate: " << GetSuccessRate() << "% (" << GetSuccessfulIterationsCount()
+            << "/" << performanceStats_.size() << " iterations)" << std::endl;
+  std::cout << "All Iterations Successful: " << (AreAllIterationsSuccessful() ? "YES" : "NO")
+            << std::endl;
+
+  std::cout << "\n--- RUNTIME STATISTICS (ms) ---" << std::endl;
+  std::cout << "Average: " << GetAverageRuntime() << " ms" << std::endl;
+  std::cout << "P50: " << GetP50Runtime() << " ms" << std::endl;
+  std::cout << "P90: " << GetP90Runtime() << " ms" << std::endl;
+  std::cout << "P95: " << GetP95Runtime() << " ms" << std::endl;
+  std::cout << "P99: " << GetP99Runtime() << " ms" << std::endl;
+  std::cout << "Min: " << GetMinRuntime() << " ms" << std::endl;
+  std::cout << "Max: " << GetMaxRuntime() << " ms" << std::endl;
+
+  std::cout << "\n--- THROUGHPUT STATISTICS (RTFX) ---" << std::endl;
+  std::cout << "Average: " << GetAverageThroughput() << " RTFX" << std::endl;
+  std::cout << "Cumulative: " << GetCumulativeThroughput() << " RTFX" << std::endl;
+  std::cout << "P90: " << GetP90Throughput() << " RTFX" << std::endl;
+  std::cout << "P95: " << GetP95Throughput() << " RTFX" << std::endl;
+  std::cout << "P99: " << GetP99Throughput() << " RTFX" << std::endl;
+
+  std::cout << "=====================================" << std::endl;
+}
+
+} // namespace nvidia::riva::utils
\ No newline at end of file
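The header below shows the intended division of labor: PerformanceStats times a single run, StatsBuilder aggregates a set of them. A minimal standalone sketch of that API (names, durations, and the sleep standing in for real work are all illustrative):

    #include <chrono>
    #include <thread>

    #include "riva/utils/stats_builder/stats_builder.h"

    int
    main()
    {
      using nvidia::riva::utils::StatsBuilder;

      // Four one-shot runs over a nominal 10 s of audio each.
      StatsBuilder builder("worker", /*audio_duration_seconds=*/10.0, /*num_iterations=*/4);

      for (std::size_t i = 0; i < 4; ++i) {
        auto& stats = builder.GetPerformanceStats(i);
        stats.StartProcessingTimer();
        std::this_thread::sleep_for(std::chrono::milliseconds(50));  // stand-in for real work
        stats.EndProcessingTimer();
        stats.SetSuccess(true);
      }

      builder.ReportDetailedStats();    // percentiles across the four runs
      builder.ReportCumulativeStats();  // per-run RTFX summary
      return 0;
    }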
diff --git a/riva/utils/stats_builder/stats_builder.h b/riva/utils/stats_builder/stats_builder.h
new file mode 100644
index 0000000..2a1c273
--- /dev/null
+++ b/riva/utils/stats_builder/stats_builder.h
@@ -0,0 +1,142 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ */
+
+#ifndef STATS_BUILDER_H
+#define STATS_BUILDER_H
+
+#include <chrono>
+#include <cstddef>
+#include <iomanip>  // Required for std::setw and std::fixed
+#include <iostream>
+#include <string>
+#include <vector>
+
+namespace nvidia::riva::utils {
+
+class PerformanceStats {
+ private:
+  bool success_;
+  std::string objectName_;
+  // Timing measurement
+  std::chrono::steady_clock::time_point processing_start_time_;
+  std::chrono::steady_clock::time_point processing_end_time_;
+  double audio_duration_seconds_;
+
+ public:
+  PerformanceStats(const std::string& objectName);
+  ~PerformanceStats() = default;
+
+  bool IsSuccess() const { return success_; }
+  void SetSuccess(bool success) { success_ = success; }
+
+  void StartProcessingTimer();
+  void EndProcessingTimer();
+  std::chrono::steady_clock::time_point GetStartTime() const { return processing_start_time_; }
+  double GetRuntimeInMs() const;
+  double GetRuntimeInSeconds() const;
+  void SetAudioDurationInSeconds(double audio_duration_seconds);
+  double GetAudioDurationInSeconds() const { return audio_duration_seconds_; }
+  double GetThroughputRTFX() const;
+
+  void SetObjectName(const std::string& objectName);
+  std::string GetObjectName() const;
+
+  void ReportStats();
+};
+
+class StatsBuilder {
+ private:
+  std::vector<PerformanceStats> performanceStats_;
+  double audio_duration_seconds_;
+  std::size_t num_iterations_;
+  std::string object_name_;  // Object name used to label per-iteration stats
+
+ public:
+  StatsBuilder(
+      const std::string& objectName, double audio_duration_seconds, std::size_t num_iterations);
+  ~StatsBuilder() = default;
+
+  void SetAudioDurationInSeconds(double audio_duration_seconds);
+  void SetNumIterations(std::size_t num_iterations);
+  void ReportCumulativeStats();
+  PerformanceStats& GetPerformanceStats(std::size_t index) { return performanceStats_[index]; }
+
+  // Statistical methods
+  double GetAverageRuntime() const;
+  double GetP50Runtime() const;
+  double GetP90Runtime() const;
+  double GetP95Runtime() const;
+  double GetP99Runtime() const;
+  double GetMinRuntime() const;
+  double GetMaxRuntime() const;
+
+  // Throughput statistics
+  double GetAverageThroughput() const;
+  double GetCumulativeThroughput() const;
+  double GetP90Throughput() const;
+  double GetP95Throughput() const;
+  double GetP99Throughput() const;
+
+  // Comprehensive reporting
+  void ReportDetailedStats() const;
+
+  // Success checking methods
+  bool AreAllIterationsSuccessful() const;
+  std::size_t GetSuccessfulIterationsCount() const;
+  std::size_t GetFailedIterationsCount() const;
+  double GetSuccessRate() const;
+
"true" : "false"; + double runtime = stats.GetRuntimeInSeconds(); // Changed to GetRuntimeInSeconds + double audio_duration = audio_duration_seconds_; // Total audio processed + double throughput = stats.GetThroughputRTFX(); // Changed to GetThroughputRTFX + + std::cout << std::left << std::setw(15) << name << std::setw(10) << success << std::fixed + << std::setprecision(3) << std::setw(12) << runtime << std::setw(15) + << audio_duration << std::setw(15) << throughput << std::endl; + } + std::cout << std::string(60, '-') << std::endl; + + // Summary row + size_t success_count = 0; + double total_runtime = 0.0; + double total_audio_processed = + audio_duration_seconds_ * performanceStats_.size(); // Total audio across all iterations + double total_throughput = 0.0; + + for (const auto& stats : performanceStats_) { + if (stats.IsSuccess()) + success_count++; + total_runtime += stats.GetRuntimeInSeconds(); // Changed to GetRuntimeInSeconds + total_throughput += stats.GetThroughputRTFX(); // Changed to GetThroughputRTFX + } + + std::cout << std::left << std::setw(15) << "SUMMARY" << std::setw(10) + << (success_count == performanceStats_.size() + ? "ALL" + : std::to_string(success_count) + "/" + + std::to_string(performanceStats_.size())) + << std::fixed << std::setprecision(3) << std::setw(12) << total_runtime + << std::setw(15) << total_audio_processed << std::setw(15) << total_throughput + << std::endl; + std::cout << std::endl; + } +}; + +} // namespace nvidia::riva::utils + +#endif // STATS_BUILDER_H \ No newline at end of file diff --git a/third_party/BUILD.websocketpp b/third_party/BUILD.websocketpp new file mode 100644 index 0000000..5269bf1 --- /dev/null +++ b/third_party/BUILD.websocketpp @@ -0,0 +1,20 @@ +""" +SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: MIT +""" + +package( + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "websocketpp", + hdrs = glob([ + "websocketpp/*.hpp", + "websocketpp/**/*.hpp" + ]), + strip_include_prefix = ".", + deps = [ + "@com_google_absl//absl/strings", + ], +) \ No newline at end of file