diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index f61908b..0b416c5 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -44,6 +44,8 @@ DEFINE_int32( "Maximum number of alternative transcripts to return (up to limit configured on server)"); DEFINE_bool( profanity_filter, false, "Flag to control profanity filtering for the generated transcripts"); +DEFINE_bool( + remove_profane_words, false, "Flag that marks removal of profane words from the transcripts"); DEFINE_bool(automatic_punctuation, true, "Flag that controls if transcript should be punctuated"); DEFINE_bool(word_time_offsets, true, "Flag that controls if word time stamps are requested"); DEFINE_string(riva_uri, "localhost:50051", "URI to access riva-server"); @@ -72,14 +74,15 @@ class RecognizeClient { public: RecognizeClient( std::shared_ptr channel, const std::string& language_code, - int32_t max_alternatives, bool profanity_filter, bool word_time_offsets, - bool automatic_punctuation, bool separate_recognition_per_channel, bool print_transcripts, - std::string output_filename, std::string model_name, bool ctm, bool verbatim_transcripts, - const std::string& boosted_phrases_file, float boosted_phrases_score, - bool speaker_diarization) + int32_t max_alternatives, bool profanity_filter, bool remove_profane_words, + bool word_time_offsets, bool automatic_punctuation, bool separate_recognition_per_channel, + bool print_transcripts, std::string output_filename, std::string model_name, bool ctm, + bool verbatim_transcripts, const std::string& boosted_phrases_file, + float boosted_phrases_score, bool speaker_diarization) : stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), - word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation), + remove_profane_words_(remove_profane_words), word_time_offsets_(word_time_offsets), + automatic_punctuation_(automatic_punctuation), separate_recognition_per_channel_(separate_recognition_per_channel), speaker_diarization_(speaker_diarization), print_transcripts_(print_transcripts), done_sending_(false), num_requests_(0), num_responses_(0), num_failed_requests_(0), @@ -193,7 +196,13 @@ class RecognizeClient { config->set_encoding(wav->encoding); config->set_language_code(language_code_); config->set_max_alternatives(max_alternatives_); - config->set_profanity_filter(profanity_filter_); + config->set_profanity_filter(nr_asr::PROFANITY_OFF); + if (profanity_filter_) { + config->set_profanity_filter(nr_asr::PROFANITY_MASK); + } + if (remove_profane_words_) { + config->set_profanity_filter(nr_asr::PROFANITY_REMOVE); + } config->set_audio_channel_count(wav->channels); config->set_enable_word_time_offsets(word_time_offsets_); config->set_enable_automatic_punctuation(automatic_punctuation_); @@ -338,6 +347,7 @@ class RecognizeClient { std::string language_code_; int32_t max_alternatives_; bool profanity_filter_; + bool remove_profane_words_; int32_t channels_; bool word_time_offsets_; bool automatic_punctuation_; @@ -377,6 +387,7 @@ main(int argc, char** argv) str_usage << " --automatic_punctuation=" << std::endl; str_usage << " --max_alternatives=" << std::endl; str_usage << " --profanity_filter=" << std::endl; + str_usage << " --remove_profane_words=" << std::endl; str_usage << " --word_time_offsets=" << std::endl; str_usage << " --riva_uri= " << std::endl; str_usage << " --num_iterations= " << std::endl; @@ -433,7 +444,7 @@ main(int argc, char** argv) RecognizeClient recognize_client( grpc_channel, FLAGS_language_code, FLAGS_max_alternatives, FLAGS_profanity_filter, - FLAGS_word_time_offsets, FLAGS_automatic_punctuation, + FLAGS_remove_profane_words, FLAGS_word_time_offsets, FLAGS_automatic_punctuation, /* separate_recognition_per_channel*/ false, FLAGS_print_transcripts, FLAGS_output_filename, FLAGS_model_name, FLAGS_output_ctm, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, (float)FLAGS_boosted_words_score, FLAGS_speaker_diarization); diff --git a/riva/clients/asr/riva_streaming_asr_client.cc b/riva/clients/asr/riva_streaming_asr_client.cc index ea7beec..25e5ced 100644 --- a/riva/clients/asr/riva_streaming_asr_client.cc +++ b/riva/clients/asr/riva_streaming_asr_client.cc @@ -48,6 +48,9 @@ DEFINE_int32( DEFINE_bool( profanity_filter, false, "Flag that controls if generated transcripts should be filtered for the profane words"); +DEFINE_bool( + remove_profane_words, false, + "Flag that controls if the profane words should be removed from the transcript"); DEFINE_bool(automatic_punctuation, true, "Flag that controls if transcript should be punctuated"); DEFINE_bool(word_time_offsets, true, "Flag that controls if word time stamps are requested"); DEFINE_bool( @@ -103,6 +106,7 @@ main(int argc, char** argv) str_usage << " --automatic_punctuation=" << std::endl; str_usage << " --max_alternatives=" << std::endl; str_usage << " --profanity_filter=" << std::endl; + str_usage << " --remove_profane_words=" << std::endl; str_usage << " --word_time_offsets=" << std::endl; str_usage << " --riva_uri= " << std::endl; str_usage << " --chunk_duration_ms= " << std::endl; @@ -161,7 +165,8 @@ main(int argc, char** argv) StreamingRecognizeClient recognize_client( grpc_channel, FLAGS_num_parallel_requests, FLAGS_language_code, FLAGS_max_alternatives, - FLAGS_profanity_filter, FLAGS_word_time_offsets, FLAGS_automatic_punctuation, + FLAGS_profanity_filter, FLAGS_remove_profane_words, FLAGS_word_time_offsets, + FLAGS_automatic_punctuation, /* separate_recognition_per_channel*/ false, FLAGS_print_transcripts, FLAGS_chunk_duration_ms, FLAGS_interim_results, FLAGS_output_filename, FLAGS_model_name, FLAGS_simulate_realtime, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score); diff --git a/riva/clients/asr/streaming_recognize_client.cc b/riva/clients/asr/streaming_recognize_client.cc index ef5299a..d30a78f 100644 --- a/riva/clients/asr/streaming_recognize_client.cc +++ b/riva/clients/asr/streaming_recognize_client.cc @@ -54,14 +54,15 @@ MicrophoneThreadMain( StreamingRecognizeClient::StreamingRecognizeClient( std::shared_ptr channel, int32_t num_parallel_requests, const std::string& language_code, int32_t max_alternatives, bool profanity_filter, - bool word_time_offsets, bool automatic_punctuation, bool separate_recognition_per_channel, - bool print_transcripts, int32_t chunk_duration_ms, bool interim_results, - std::string output_filename, std::string model_name, bool simulate_realtime, - bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score) + bool remove_profane_words, bool word_time_offsets, bool automatic_punctuation, + bool separate_recognition_per_channel, bool print_transcripts, int32_t chunk_duration_ms, + bool interim_results, std::string output_filename, std::string model_name, + bool simulate_realtime, bool verbatim_transcripts, const std::string& boosted_phrases_file, + float boosted_phrases_score) : print_latency_stats_(true), stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), - profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), - automatic_punctuation_(automatic_punctuation), + profanity_filter_(profanity_filter), remove_profane_words_(remove_profane_words), + word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation), separate_recognition_per_channel_(separate_recognition_per_channel), print_transcripts_(print_transcripts), chunk_duration_ms_(chunk_duration_ms), interim_results_(interim_results), total_audio_processed_(0.), num_streams_started_(0), @@ -123,7 +124,13 @@ StreamingRecognizeClient::GenerateRequests(std::shared_ptr call) config->set_language_code(language_code_); config->set_encoding(call->stream->wav->encoding); config->set_max_alternatives(max_alternatives_); - config->set_profanity_filter(profanity_filter_); + config->set_profanity_filter(nr_asr::PROFANITY_OFF); + if (profanity_filter_) { + config->set_profanity_filter(nr_asr::PROFANITY_MASK); + } + if (remove_profane_words_) { + config->set_profanity_filter(nr_asr::PROFANITY_REMOVE); + } config->set_audio_channel_count(call->stream->wav->channels); config->set_enable_word_time_offsets(word_time_offsets_); config->set_enable_automatic_punctuation(automatic_punctuation_); @@ -376,7 +383,13 @@ StreamingRecognizeClient::DoStreamingFromMicrophone( config->set_language_code(language_code_); config->set_encoding(encoding); config->set_max_alternatives(max_alternatives_); - config->set_profanity_filter(profanity_filter_); + config->set_profanity_filter(nr_asr::PROFANITY_OFF); + if (profanity_filter_) { + config->set_profanity_filter(nr_asr::PROFANITY_MASK); + } + if (remove_profane_words_) { + config->set_profanity_filter(nr_asr::PROFANITY_REMOVE); + } config->set_audio_channel_count(channels); config->set_enable_word_time_offsets(word_time_offsets_); config->set_enable_automatic_punctuation(automatic_punctuation_); diff --git a/riva/clients/asr/streaming_recognize_client.h b/riva/clients/asr/streaming_recognize_client.h index 403cf25..7bcdb51 100644 --- a/riva/clients/asr/streaming_recognize_client.h +++ b/riva/clients/asr/streaming_recognize_client.h @@ -43,10 +43,10 @@ class StreamingRecognizeClient { StreamingRecognizeClient( std::shared_ptr channel, int32_t num_parallel_requests, const std::string& language_code, int32_t max_alternatives, bool profanity_filter, - bool word_time_offsets, bool automatic_punctuation, bool separate_recognition_per_channel, - bool print_transcripts, int32_t chunk_duration_ms, bool interim_results, - std::string output_filename, std::string model_name, bool simulate_realtime, - bool verbatim_transcripts, const std::string& boosted_phrases_file, + bool remove_profane_words, bool word_time_offsets, bool automatic_punctuation, + bool separate_recognition_per_channel, bool print_transcripts, int32_t chunk_duration_ms, + bool interim_results, std::string output_filename, std::string model_name, + bool simulate_realtime, bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score); ~StreamingRecognizeClient(); @@ -87,6 +87,7 @@ class StreamingRecognizeClient { std::string language_code_; int32_t max_alternatives_; bool profanity_filter_; + bool remove_profane_words_; int32_t channels_; bool word_time_offsets_; bool automatic_punctuation_; diff --git a/riva/clients/asr/streaming_recognize_client_test.cc b/riva/clients/asr/streaming_recognize_client_test.cc index 38aa185..40522ee 100644 --- a/riva/clients/asr/streaming_recognize_client_test.cc +++ b/riva/clients/asr/streaming_recognize_client_test.cc @@ -19,8 +19,8 @@ TEST(StreamingRecognizeClient, num_responses_requests) auto current_time = std::chrono::steady_clock::now(); StreamingRecognizeClient recognize_client( - grpc_channel, 1, "en-US", 1, false, false, false, false, false, 800, false, "dummy.txt", - "dummy", true, true, "", 10.); + grpc_channel, 1, "en-US", 1, false, false, false, false, false, false, 800, false, + "dummy.txt", "dummy", true, true, "", 10.); std::shared_ptr call = std::make_shared(1, true); uint32_t num_sends = 10; diff --git a/riva/clients/nmt/riva_nmt_streaming_s2s_client.cc b/riva/clients/nmt/riva_nmt_streaming_s2s_client.cc index c63ad7e..9be6a28 100644 --- a/riva/clients/nmt/riva_nmt_streaming_s2s_client.cc +++ b/riva/clients/nmt/riva_nmt_streaming_s2s_client.cc @@ -44,6 +44,8 @@ DEFINE_string( DEFINE_bool( profanity_filter, false, "Flag that controls if generated transcripts should be filtered for the profane words"); +DEFINE_bool( + remove_profane_words, false, "Flag that controls if profane word need to be removed as well"); DEFINE_bool(automatic_punctuation, true, "Flag that controls if transcript should be punctuated"); DEFINE_bool( simulate_realtime, false, "Flag that controls if audio files should be sent in realtime"); @@ -97,6 +99,7 @@ main(int argc, char** argv) str_usage << " --audio_device= " << std::endl; str_usage << " --automatic_punctuation=" << std::endl; str_usage << " --profanity_filter=" << std::endl; + str_usage << " --remove_profane_words=" << std::endl; str_usage << " --riva_uri= " << std::endl; str_usage << " --chunk_duration_ms= " << std::endl; str_usage << " --simulate_realtime= " << std::endl; @@ -159,7 +162,8 @@ main(int argc, char** argv) StreamingS2SClient recognize_client( grpc_channel, FLAGS_num_parallel_requests, FLAGS_source_language_code, - FLAGS_target_language_code, FLAGS_profanity_filter, FLAGS_automatic_punctuation, + FLAGS_target_language_code, FLAGS_profanity_filter, FLAGS_remove_profane_words, + FLAGS_automatic_punctuation, /* separate_recognition_per_channel*/ false, FLAGS_chunk_duration_ms, FLAGS_simulate_realtime, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score, FLAGS_tts_encoding, FLAGS_tts_audio_file, FLAGS_tts_sample_rate, FLAGS_tts_voice_name); diff --git a/riva/clients/nmt/riva_nmt_streaming_s2t_client.cc b/riva/clients/nmt/riva_nmt_streaming_s2t_client.cc index aa32a40..f614fd2 100644 --- a/riva/clients/nmt/riva_nmt_streaming_s2t_client.cc +++ b/riva/clients/nmt/riva_nmt_streaming_s2t_client.cc @@ -44,6 +44,8 @@ DEFINE_string( DEFINE_bool( profanity_filter, false, "Flag that controls if generated transcripts should be filtered for the profane words"); +DEFINE_bool( + remove_profane_words, false, "Flag that controls if profane word need to be removed as well"); DEFINE_bool(automatic_punctuation, true, "Flag that controls if transcript should be punctuated"); DEFINE_bool( simulate_realtime, false, "Flag that controls if audio files should be sent in realtime"); @@ -93,6 +95,7 @@ main(int argc, char** argv) str_usage << " --audio_device= " << std::endl; str_usage << " --automatic_punctuation=" << std::endl; str_usage << " --profanity_filter=" << std::endl; + str_usage << " --remove_profane_words=" << std::endl; str_usage << " --riva_uri= " << std::endl; str_usage << " --chunk_duration_ms= " << std::endl; str_usage << " --simulate_realtime= " << std::endl; @@ -146,7 +149,8 @@ main(int argc, char** argv) StreamingS2TClient recognize_client( grpc_channel, FLAGS_num_parallel_requests, FLAGS_source_language_code, - FLAGS_target_language_code, FLAGS_profanity_filter, FLAGS_automatic_punctuation, + FLAGS_target_language_code, FLAGS_profanity_filter, FLAGS_remove_profane_words, + FLAGS_automatic_punctuation, /* separate_recognition_per_channel*/ false, FLAGS_chunk_duration_ms, FLAGS_simulate_realtime, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score, FLAGS_nmt_text_file); diff --git a/riva/clients/nmt/streaming_s2s_client.cc b/riva/clients/nmt/streaming_s2s_client.cc index 8bde539..16ab705 100644 --- a/riva/clients/nmt/streaming_s2s_client.cc +++ b/riva/clients/nmt/streaming_s2s_client.cc @@ -3,7 +3,6 @@ * SPDX-License-Identifier: MIT */ - #include "streaming_s2s_client.h" #include "riva/utils/opus/opus_client_decoder.h" @@ -54,14 +53,14 @@ MicrophoneThreadMain( StreamingS2SClient::StreamingS2SClient( std::shared_ptr channel, int32_t num_parallel_requests, const std::string& source_language_code, const std::string& target_language_code, - bool profanity_filter, bool automatic_punctuation, bool separate_recognition_per_channel, - int32_t chunk_duration_ms, bool simulate_realtime, bool verbatim_transcripts, - const std::string& boosted_phrases_file, float boosted_phrases_score, + bool profanity_filter, bool remove_profane_words, bool automatic_punctuation, + bool separate_recognition_per_channel, int32_t chunk_duration_ms, bool simulate_realtime, + bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, const std::string& tts_encoding, const std::string& tts_audio_file, int tts_sample_rate, const std::string& tts_voice_name) - : print_latency_stats_(true), stub_(nr_nmt::RivaTranslation::NewStub(channel)), - source_language_code_(source_language_code), target_language_code_(target_language_code), - profanity_filter_(profanity_filter), automatic_punctuation_(automatic_punctuation), + : stub_(nr_nmt::RivaTranslation::NewStub(channel)), source_language_code_(source_language_code), + target_language_code_(target_language_code), profanity_filter_(profanity_filter), + remove_profane_words_(remove_profane_words), automatic_punctuation_(automatic_punctuation), separate_recognition_per_channel_(separate_recognition_per_channel), chunk_duration_ms_(chunk_duration_ms), total_audio_processed_(0.), num_streams_started_(0), simulate_realtime_(simulate_realtime), verbatim_transcripts_(verbatim_transcripts), @@ -138,7 +137,14 @@ StreamingS2SClient::GenerateRequests(std::shared_ptr call) config->set_language_code(source_language_code_); config->set_encoding(call->stream->wav->encoding); config->set_max_alternatives(1); - config->set_profanity_filter(profanity_filter_); + config->set_profanity_filter(nr_asr::PROFANITY_OFF); + if (profanity_filter_) { + config->set_profanity_filter(nr_asr::PROFANITY_MASK); + } + if (remove_profane_words_) { + config->set_profanity_filter(nr_asr::PROFANITY_REMOVE); + } + config->set_audio_channel_count(call->stream->wav->channels); config->set_enable_word_time_offsets(false); config->set_enable_automatic_punctuation(automatic_punctuation_); @@ -185,6 +191,8 @@ StreamingS2SClient::GenerateRequests(std::shared_ptr call) } } call->send_times.push_back(std::chrono::steady_clock::now()); + // std::time_t end_time = std::chrono::system_clock::to_time_t(call->send_times.back()); + // std::cout << &call << "Send time" << std::ctime(&end_time); call->streamer->Write(request); // Set write done to true so next call will lead to WritesDone @@ -198,7 +206,6 @@ StreamingS2SClient::GenerateRequests(std::shared_ptr call) std::lock_guard lock(latencies_mutex_); total_audio_processed_ += audio_processed; } - num_active_streams_--; } int @@ -266,23 +273,14 @@ void StreamingS2SClient::PostProcessResults(std::shared_ptr call, bool audio_device) { std::lock_guard lock(latencies_mutex_); - // it is possible we get one response more than the number of requests sent - // in the case where files are perfect multiple of chunk size - if (call->recv_times.size() != call->send_times.size() && - call->recv_times.size() != call->send_times.size() + 1) { - print_latency_stats_ = false; - } else { - for (uint32_t time_cnt = 0; time_cnt < call->send_times.size(); ++time_cnt) { - double lat = std::chrono::duration( - call->recv_times[time_cnt] - call->send_times[time_cnt]) - .count(); - if (call->recv_final_flags[time_cnt]) { - final_latencies_.push_back(lat); - } else { - int_latencies_.push_back(lat); - } - latencies_.push_back(lat); - } + // the latency for the s2s would be for an individual file as the difference between the last + // chunk sent to the first chunk of audio recieved. + if (simulate_realtime_) { + double lat = + std::chrono::duration(call->recv_times[0] - call->send_times.back()) + .count(); + std::cout << "Latency:" << lat << std::endl; + latencies_.push_back(lat); } } @@ -340,8 +338,11 @@ StreamingS2SClient::ReceiveResponses(std::shared_ptr call, bool a if (!status.ok()) { // Report the RPC failure. std::cerr << status.error_message() << std::endl; + } else { + PostProcessResults(call, audio_device); } - + // A stream would be marked as complete when both ASR and TTS are complete + num_active_streams_--; num_streams_finished_++; } @@ -398,7 +399,13 @@ StreamingS2SClient::DoStreamingFromMicrophone(const std::string& audio_device, b config->set_language_code(source_language_code_); config->set_encoding(encoding); config->set_max_alternatives(1); - config->set_profanity_filter(profanity_filter_); + config->set_profanity_filter(nr_asr::PROFANITY_OFF); + if (profanity_filter_) { + config->set_profanity_filter(nr_asr::PROFANITY_MASK); + } + if (remove_profane_words_) { + config->set_profanity_filter(nr_asr::PROFANITY_REMOVE); + } config->set_audio_channel_count(channels); config->set_enable_word_time_offsets(false); config->set_enable_automatic_punctuation(automatic_punctuation_); @@ -448,18 +455,13 @@ StreamingS2SClient::PrintLatencies(std::vector& latencies, const std::st int StreamingS2SClient::PrintStats() { - if (print_latency_stats_ && simulate_realtime_) { + if (simulate_realtime_) { PrintLatencies(latencies_, "Latencies"); - PrintLatencies(int_latencies_, "Intermediate latencies"); - PrintLatencies(final_latencies_, "Final latencies"); return 0; } else { - std::cout - << "Not printing latency statistics because the client is run without the " - "--simulate_realtime option and/or the number of requests sent is not equal to " - "number of requests received. To get latency statistics, run with --simulate_realtime " - "and set the --chunk_duration_ms to be the same as the server chunk duration" - << std::endl; + std::cout << "To get latency statistics, run with --simulate_realtime " + "and set the --chunk_duration_ms to be the same as the server chunk duration" + << std::endl; return 1; } } diff --git a/riva/clients/nmt/streaming_s2s_client.h b/riva/clients/nmt/streaming_s2s_client.h index 637d952..fb5d7dc 100644 --- a/riva/clients/nmt/streaming_s2s_client.h +++ b/riva/clients/nmt/streaming_s2s_client.h @@ -50,11 +50,11 @@ class StreamingS2SClient { StreamingS2SClient( std::shared_ptr channel, int32_t num_parallel_requests, const std::string& source_language_code, const std::string& target_language_code_, - bool profanity_filter, bool automatic_punctuation, bool separate_recognition_per_channel, - int32_t chunk_duration_ms, bool simulate_realtime, bool verbatim_transcripts, - const std::string& boosted_phrases_file, float boosted_phrases_score, - const std::string& tts_encoding, const std::string& tts_audio_file, int tts_sample_rate, - const std::string& tts_voice_name); + bool profanity_filter, bool remove_profane_words, bool automatic_punctuation, + bool separate_recognition_per_channel, int32_t chunk_duration_ms, bool simulate_realtime, + bool verbatim_transcripts, const std::string& boosted_phrases_file, + float boosted_phrases_score, const std::string& tts_encoding, + const std::string& tts_audio_file, int tts_sample_rate, const std::string& tts_voice_name); ~StreamingS2SClient(); @@ -82,13 +82,11 @@ class StreamingS2SClient { std::mutex latencies_mutex_; - bool print_latency_stats_; - private: // Out of the passed in Channel comes the stub, stored here, our view of the // server's exposed services. std::unique_ptr stub_; - std::vector int_latencies_, final_latencies_, latencies_; + std::vector latencies_; std::string tts_encoding_; std::string tts_audio_file_; std::string tts_voice_name_; @@ -97,6 +95,7 @@ class StreamingS2SClient { int tts_sample_rate_; bool profanity_filter_; + bool remove_profane_words_; int32_t channels_; bool automatic_punctuation_; bool separate_recognition_per_channel_; diff --git a/riva/clients/nmt/streaming_s2t_client.cc b/riva/clients/nmt/streaming_s2t_client.cc index b7f8be1..3afa396 100644 --- a/riva/clients/nmt/streaming_s2t_client.cc +++ b/riva/clients/nmt/streaming_s2t_client.cc @@ -53,13 +53,13 @@ MicrophoneThreadMain( StreamingS2TClient::StreamingS2TClient( std::shared_ptr channel, int32_t num_parallel_requests, const std::string& source_language_code, const std::string& target_language_code, - bool profanity_filter, bool automatic_punctuation, bool separate_recognition_per_channel, - int32_t chunk_duration_ms, bool simulate_realtime, bool verbatim_transcripts, - const std::string& boosted_phrases_file, float boosted_phrases_score, + bool profanity_filter, bool remove_profane_words, bool automatic_punctuation, + bool separate_recognition_per_channel, int32_t chunk_duration_ms, bool simulate_realtime, + bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, const std::string& nmt_text_file) - : print_latency_stats_(true), stub_(nr_nmt::RivaTranslation::NewStub(channel)), - source_language_code_(source_language_code), target_language_code_(target_language_code), - profanity_filter_(profanity_filter), automatic_punctuation_(automatic_punctuation), + : stub_(nr_nmt::RivaTranslation::NewStub(channel)), source_language_code_(source_language_code), + target_language_code_(target_language_code), profanity_filter_(profanity_filter), + remove_profane_words_(remove_profane_words), automatic_punctuation_(automatic_punctuation), separate_recognition_per_channel_(separate_recognition_per_channel), chunk_duration_ms_(chunk_duration_ms), total_audio_processed_(0.), num_streams_started_(0), simulate_realtime_(simulate_realtime), verbatim_transcripts_(verbatim_transcripts), @@ -119,7 +119,14 @@ StreamingS2TClient::GenerateRequests(std::shared_ptr call) config->set_language_code(source_language_code_); config->set_encoding(call->stream->wav->encoding); config->set_max_alternatives(1); - config->set_profanity_filter(profanity_filter_); + config->set_profanity_filter(nr_asr::PROFANITY_OFF); + if (profanity_filter_) { + config->set_profanity_filter(nr_asr::PROFANITY_MASK); + } + if (remove_profane_words_) { + config->set_profanity_filter(nr_asr::PROFANITY_REMOVE); + } + config->set_audio_channel_count(call->stream->wav->channels); config->set_enable_word_time_offsets(false); config->set_enable_automatic_punctuation(automatic_punctuation_); @@ -247,23 +254,13 @@ void StreamingS2TClient::PostProcessResults(std::shared_ptr call, bool audio_device) { std::lock_guard lock(latencies_mutex_); - // it is possible we get one response more than the number of requests sent - // in the case where files are perfect multiple of chunk size - if (call->recv_times.size() != call->send_times.size() && - call->recv_times.size() != call->send_times.size() + 1) { - print_latency_stats_ = false; - } else { - for (uint32_t time_cnt = 0; time_cnt < call->send_times.size(); ++time_cnt) { - double lat = std::chrono::duration( - call->recv_times[time_cnt] - call->send_times[time_cnt]) - .count(); - if (call->recv_final_flags[time_cnt]) { - final_latencies_.push_back(lat); - } else { - int_latencies_.push_back(lat); - } - latencies_.push_back(lat); - } + + if (simulate_realtime_) { + double lat = + std::chrono::duration(call->recv_times[0] - call->send_times.back()) + .count(); + std::cout << "Latency:" << lat << std::endl; + latencies_.push_back(lat); } } @@ -299,6 +296,8 @@ StreamingS2TClient::ReceiveResponses(std::shared_ptr call, bool a if (!status.ok()) { // Report the RPC failure. std::cerr << status.error_message() << std::endl; + } else { + PostProcessResults(call, audio_device); } num_streams_finished_++; @@ -341,7 +340,13 @@ StreamingS2TClient::DoStreamingFromMicrophone(const std::string& audio_device, b config->set_language_code(source_language_code_); config->set_encoding(encoding); config->set_max_alternatives(1); - config->set_profanity_filter(profanity_filter_); + config->set_profanity_filter(nr_asr::PROFANITY_OFF); + if (profanity_filter_) { + config->set_profanity_filter(nr_asr::PROFANITY_MASK); + } + if (remove_profane_words_) { + config->set_profanity_filter(nr_asr::PROFANITY_REMOVE); + } config->set_audio_channel_count(channels); config->set_enable_word_time_offsets(false); config->set_enable_automatic_punctuation(automatic_punctuation_); @@ -391,18 +396,13 @@ StreamingS2TClient::PrintLatencies(std::vector& latencies, const std::st int StreamingS2TClient::PrintStats() { - if (print_latency_stats_ && simulate_realtime_) { + if (simulate_realtime_) { PrintLatencies(latencies_, "Latencies"); - PrintLatencies(int_latencies_, "Intermediate latencies"); - PrintLatencies(final_latencies_, "Final latencies"); return 0; } else { - std::cout - << "Not printing latency statistics because the client is run without the " - "--simulate_realtime option and/or the number of requests sent is not equal to " - "number of requests received. To get latency statistics, run with --simulate_realtime " - "and set the --chunk_duration_ms to be the same as the server chunk duration" - << std::endl; + std::cout << "To get latency statistics, run with --simulate_realtime " + "and set the --chunk_duration_ms to be the same as the server chunk duration" + << std::endl; return 1; } } diff --git a/riva/clients/nmt/streaming_s2t_client.h b/riva/clients/nmt/streaming_s2t_client.h index 5caccad..08df08d 100644 --- a/riva/clients/nmt/streaming_s2t_client.h +++ b/riva/clients/nmt/streaming_s2t_client.h @@ -49,10 +49,10 @@ class StreamingS2TClient { StreamingS2TClient( std::shared_ptr channel, int32_t num_parallel_requests, const std::string& source_language_code, const std::string& target_language_code, - bool profanity_filter, bool automatic_punctuation, bool separate_recognition_per_channel, - int32_t chunk_duration_ms, bool simulate_realtime, bool verbatim_transcripts, - const std::string& boosted_phrases_file, float boosted_phrases_score, - const std::string& nmt_text_file); + bool profanity_filter, bool remove_profane_words, bool automatic_punctuation, + bool separate_recognition_per_channel, int32_t chunk_duration_ms, bool simulate_realtime, + bool verbatim_transcripts, const std::string& boosted_phrases_file, + float boosted_phrases_score, const std::string& nmt_text_file); ~StreamingS2TClient(); @@ -81,17 +81,16 @@ class StreamingS2TClient { std::mutex latencies_mutex_; - bool print_latency_stats_; - private: // Out of the passed in Channel comes the stub, stored here, our view of the // server's exposed services. std::unique_ptr stub_; - std::vector int_latencies_, final_latencies_, latencies_; + std::vector latencies_; std::string source_language_code_; std::string target_language_code_; bool profanity_filter_; + bool remove_profane_words_; int32_t channels_; bool automatic_punctuation_; bool separate_recognition_per_channel_;