diff --git a/riva/clients/asr/riva_streaming_asr_client.cc b/riva/clients/asr/riva_streaming_asr_client.cc index 938b47d..4815640 100644 --- a/riva/clients/asr/riva_streaming_asr_client.cc +++ b/riva/clients/asr/riva_streaming_asr_client.cc @@ -76,6 +76,7 @@ DEFINE_bool( "Whether to use SSL credentials or not. If ssl_cert is specified, " "this is assumed to be true"); DEFINE_string(metadata, "", "Comma separated key-value pair(s) of metadata to be sent to server"); +DEFINE_int32(async_delay_ms, 0, "Delay to start parallel request asynchronously in milliseconds"); void signal_handler(int signal_num) @@ -118,6 +119,7 @@ main(int argc, char** argv) str_usage << " --boosted_words_score=" << std::endl; str_usage << " --ssl_cert=" << std::endl; str_usage << " --metadata=" << std::endl; + str_usage << " --async_delay_ms=" << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -164,7 +166,8 @@ main(int argc, char** argv) FLAGS_profanity_filter, FLAGS_word_time_offsets, FLAGS_automatic_punctuation, /* separate_recognition_per_channel*/ false, FLAGS_print_transcripts, FLAGS_chunk_duration_ms, FLAGS_interim_results, FLAGS_output_filename, FLAGS_model_name, FLAGS_simulate_realtime, - FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score); + FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score, + FLAGS_async_delay_ms); if (FLAGS_audio_file.size()) { return recognize_client.DoStreamingFromFile( @@ -205,4 +208,4 @@ main(int argc, char** argv) } return 0; -} +} \ No newline at end of file diff --git a/riva/clients/asr/streaming_recognize_client.cc b/riva/clients/asr/streaming_recognize_client.cc index ef5299a..3117827 100644 --- a/riva/clients/asr/streaming_recognize_client.cc +++ b/riva/clients/asr/streaming_recognize_client.cc @@ -57,7 +57,8 @@ StreamingRecognizeClient::StreamingRecognizeClient( bool word_time_offsets, bool automatic_punctuation, bool 
separate_recognition_per_channel, bool print_transcripts, int32_t chunk_duration_ms, bool interim_results, std::string output_filename, std::string model_name, bool simulate_realtime, - bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score) + bool verbatim_transcripts, const std::string& boosted_phrases_file, + float boosted_phrases_score, int32_t async_delay_ms) : print_latency_stats_(true), stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), @@ -66,7 +67,8 @@ StreamingRecognizeClient::StreamingRecognizeClient( print_transcripts_(print_transcripts), chunk_duration_ms_(chunk_duration_ms), interim_results_(interim_results), total_audio_processed_(0.), num_streams_started_(0), model_name_(model_name), simulate_realtime_(simulate_realtime), - verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score) + verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score), + async_delay_ms_(async_delay_ms) { num_active_streams_.store(0); num_streams_finished_.store(0); @@ -218,17 +220,32 @@ StreamingRecognizeClient::DoStreamingFromFile( } } + ; + // Ensure there's also num_parallel_requests in flight uint32_t all_wav_i = 0; auto start_time = std::chrono::steady_clock::now(); + + uint32_t initial_streams = 0; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(1, async_delay_ms_ > 0 ? async_delay_ms_ : 1); /* distribution requires a <= b; the flag defaults to 0, so clamp the upper bound (only sampled when async_delay_ms_ > 0) */ + while (true) { while (NumActiveStreams() < (uint32_t)num_parallel_requests && all_wav_i < all_wav_max) { + if(async_delay_ms_>0){ + if(initial_streams < (uint32_t)num_parallel_requests) { + std::this_thread::sleep_for(std::chrono::milliseconds(dis(gen))); + initial_streams++; + } + } + std::unique_ptr stream(new Stream(all_wav_repeated[all_wav_i], all_wav_i)); StartNewStream(std::move(stream)); ++all_wav_i; } - - // 
Break if no more tasks to add + + // Break if no more tasks to add if (NumStreamsFinished() == all_wav_max) { break; } @@ -444,4 +461,4 @@ StreamingRecognizeClient::PrintStats() << std::endl; return 1; } -} +} \ No newline at end of file diff --git a/riva/clients/asr/streaming_recognize_client.h b/riva/clients/asr/streaming_recognize_client.h index 14e7d17..b7c5d5a 100644 --- a/riva/clients/asr/streaming_recognize_client.h +++ b/riva/clients/asr/streaming_recognize_client.h @@ -25,7 +25,7 @@ #include #include #include - +#include #include "client_call.h" #include "riva/proto/riva_asr.grpc.pb.h" #include "riva/utils/thread_pool.h" @@ -47,7 +47,7 @@ class StreamingRecognizeClient { bool print_transcripts, int32_t chunk_duration_ms, bool interim_results, std::string output_filename, std::string model_name, bool simulate_realtime, bool verbatim_transcripts, const std::string& boosted_phrases_file, - float boosted_phrases_score); + float boosted_phrases_score, int32_t async_delay_ms); ~StreamingRecognizeClient(); @@ -114,4 +114,5 @@ class StreamingRecognizeClient { std::vector boosted_phrases_; float boosted_phrases_score_; -}; + int32_t async_delay_ms_; +}; \ No newline at end of file diff --git a/riva/clients/asr/streaming_recognize_client_test.cc b/riva/clients/asr/streaming_recognize_client_test.cc index 38aa185..51ecc3f 100644 --- a/riva/clients/asr/streaming_recognize_client_test.cc +++ b/riva/clients/asr/streaming_recognize_client_test.cc @@ -20,7 +20,7 @@ TEST(StreamingRecognizeClient, num_responses_requests) StreamingRecognizeClient recognize_client( grpc_channel, 1, "en-US", 1, false, false, false, false, false, 800, false, "dummy.txt", - "dummy", true, true, "", 10.); + "dummy", true, true, "", 10., 100); std::shared_ptr call = std::make_shared(1, true); uint32_t num_sends = 10;