Skip to content

perf: Add delay "async_delay_ms" within the Cpp client to trigger parallel requests asynchronously #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions riva/clients/asr/riva_streaming_asr_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ DEFINE_bool(
"Whether to use SSL credentials or not. If ssl_cert is specified, "
"this is assumed to be true");
DEFINE_string(metadata, "", "Comma separated key-value pair(s) of metadata to be sent to server");
DEFINE_int32(async_delay_ms, 0, "Delay to start parallel request asynchronously in milliseconds");

void
signal_handler(int signal_num)
Expand Down Expand Up @@ -118,6 +119,7 @@ main(int argc, char** argv)
str_usage << " --boosted_words_score=<float>" << std::endl;
str_usage << " --ssl_cert=<filename>" << std::endl;
str_usage << " --metadata=<key,value,...>" << std::endl;
str_usage << " --async_delay_ms=<integer>" << std::endl;
gflags::SetUsageMessage(str_usage.str());
gflags::SetVersionString(::riva::utils::kBuildScmRevision);

Expand Down Expand Up @@ -164,7 +166,8 @@ main(int argc, char** argv)
FLAGS_profanity_filter, FLAGS_word_time_offsets, FLAGS_automatic_punctuation,
/* separate_recognition_per_channel*/ false, FLAGS_print_transcripts, FLAGS_chunk_duration_ms,
FLAGS_interim_results, FLAGS_output_filename, FLAGS_model_name, FLAGS_simulate_realtime,
FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score);
FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score,
FLAGS_async_delay_ms);

if (FLAGS_audio_file.size()) {
return recognize_client.DoStreamingFromFile(
Expand Down Expand Up @@ -205,4 +208,4 @@ main(int argc, char** argv)
}

return 0;
}
}
27 changes: 22 additions & 5 deletions riva/clients/asr/streaming_recognize_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ StreamingRecognizeClient::StreamingRecognizeClient(
bool word_time_offsets, bool automatic_punctuation, bool separate_recognition_per_channel,
bool print_transcripts, int32_t chunk_duration_ms, bool interim_results,
std::string output_filename, std::string model_name, bool simulate_realtime,
bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score)
bool verbatim_transcripts, const std::string& boosted_phrases_file,
float boosted_phrases_score, int32_t async_delay_ms)
: print_latency_stats_(true), stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)),
language_code_(language_code), max_alternatives_(max_alternatives),
profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets),
Expand All @@ -66,7 +67,8 @@ StreamingRecognizeClient::StreamingRecognizeClient(
print_transcripts_(print_transcripts), chunk_duration_ms_(chunk_duration_ms),
interim_results_(interim_results), total_audio_processed_(0.), num_streams_started_(0),
model_name_(model_name), simulate_realtime_(simulate_realtime),
verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score)
verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score),
async_delay_ms_(async_delay_ms)
{
num_active_streams_.store(0);
num_streams_finished_.store(0);
Expand Down Expand Up @@ -218,17 +220,32 @@ StreamingRecognizeClient::DoStreamingFromFile(
}
}

;

// Ensure there's also num_parallel_requests in flight
uint32_t all_wav_i = 0;
auto start_time = std::chrono::steady_clock::now();

// Number of initial streams that have already been staggered with a random
// start-up delay.
uint32_t initial_streams = 0;

// Random source for the per-stream start-up delay, in milliseconds.
std::random_device rd;
std::mt19937 gen(rd());

// std::uniform_int_distribution requires lower <= upper. async_delay_ms_ may be
// zero (the flag's default) or negative, so the upper bound is clamped to keep
// construction well-defined. The distribution is only sampled when
// async_delay_ms_ > 0, so the clamp does not change the delays actually used.
const int32_t max_delay_ms = (async_delay_ms_ > 0) ? async_delay_ms_ : 1;
std::uniform_int_distribution<> dis(1, max_delay_ms);

while (true) {
while (NumActiveStreams() < (uint32_t)num_parallel_requests && all_wav_i < all_wav_max) {
if(async_delay_ms_>0){
if(initial_streams < (uint32_t)num_parallel_requests) {
std::this_thread::sleep_for(std::chrono::milliseconds(dis(gen)));
initial_streams++;
}
}

std::unique_ptr<Stream> stream(new Stream(all_wav_repeated[all_wav_i], all_wav_i));
StartNewStream(std::move(stream));
++all_wav_i;
}

// Break if no more tasks to add
// Break if no more tasks to add
if (NumStreamsFinished() == all_wav_max) {
break;
}
Expand Down Expand Up @@ -444,4 +461,4 @@ StreamingRecognizeClient::PrintStats()
<< std::endl;
return 1;
}
}
}
7 changes: 4 additions & 3 deletions riva/clients/asr/streaming_recognize_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include <sstream>
#include <string>
#include <thread>

#include <random>
#include "client_call.h"
#include "riva/proto/riva_asr.grpc.pb.h"
#include "riva/utils/thread_pool.h"
Expand All @@ -47,7 +47,7 @@ class StreamingRecognizeClient {
bool print_transcripts, int32_t chunk_duration_ms, bool interim_results,
std::string output_filename, std::string model_name, bool simulate_realtime,
bool verbatim_transcripts, const std::string& boosted_phrases_file,
float boosted_phrases_score);
float boosted_phrases_score, int32_t async_delay_ms);

~StreamingRecognizeClient();

Expand Down Expand Up @@ -114,4 +114,5 @@ class StreamingRecognizeClient {

std::vector<std::string> boosted_phrases_;
float boosted_phrases_score_;
};
int32_t async_delay_ms_;
};
2 changes: 1 addition & 1 deletion riva/clients/asr/streaming_recognize_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ TEST(StreamingRecognizeClient, num_responses_requests)

StreamingRecognizeClient recognize_client(
grpc_channel, 1, "en-US", 1, false, false, false, false, false, 800, false, "dummy.txt",
"dummy", true, true, "", 10.);
"dummy", true, true, "", 10., 100);

std::shared_ptr<ClientCall> call = std::make_shared<ClientCall>(1, true);
uint32_t num_sends = 10;
Expand Down