Skip to content

Update clients to use the profanity settings #64

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: release/2.13.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions riva/clients/asr/riva_asr_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ DEFINE_int32(
"Maximum number of alternative transcripts to return (up to limit configured on server)");
DEFINE_bool(
profanity_filter, false, "Flag to control profanity filtering for the generated transcripts");
// Command-line switch forwarded to RecognizeClient. When true, the recognition
// config's profanity filter is set to nr_asr::PROFANITY_REMOVE. That assignment
// happens after the --profanity_filter check, so it overrides the
// nr_asr::PROFANITY_MASK setting when both flags are enabled.
DEFINE_bool(
remove_profane_words, false, "Flag that marks removal of profane words from the transcripts");
DEFINE_bool(automatic_punctuation, true, "Flag that controls if transcript should be punctuated");
DEFINE_bool(word_time_offsets, true, "Flag that controls if word time stamps are requested");
DEFINE_string(riva_uri, "localhost:50051", "URI to access riva-server");
Expand Down Expand Up @@ -72,14 +74,15 @@ class RecognizeClient {
public:
RecognizeClient(
std::shared_ptr<grpc::Channel> channel, const std::string& language_code,
int32_t max_alternatives, bool profanity_filter, bool word_time_offsets,
bool automatic_punctuation, bool separate_recognition_per_channel, bool print_transcripts,
std::string output_filename, std::string model_name, bool ctm, bool verbatim_transcripts,
const std::string& boosted_phrases_file, float boosted_phrases_score,
bool speaker_diarization)
int32_t max_alternatives, bool profanity_filter, bool remove_profane_words,
bool word_time_offsets, bool automatic_punctuation, bool separate_recognition_per_channel,
bool print_transcripts, std::string output_filename, std::string model_name, bool ctm,
bool verbatim_transcripts, const std::string& boosted_phrases_file,
float boosted_phrases_score, bool speaker_diarization)
: stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code),
max_alternatives_(max_alternatives), profanity_filter_(profanity_filter),
word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation),
remove_profane_words_(remove_profane_words), word_time_offsets_(word_time_offsets),
automatic_punctuation_(automatic_punctuation),
separate_recognition_per_channel_(separate_recognition_per_channel),
speaker_diarization_(speaker_diarization), print_transcripts_(print_transcripts),
done_sending_(false), num_requests_(0), num_responses_(0), num_failed_requests_(0),
Expand Down Expand Up @@ -193,7 +196,13 @@ class RecognizeClient {
config->set_encoding(wav->encoding);
config->set_language_code(language_code_);
config->set_max_alternatives(max_alternatives_);
config->set_profanity_filter(profanity_filter_);
config->set_profanity_filter(nr_asr::PROFANITY_OFF);
if (profanity_filter_) {
config->set_profanity_filter(nr_asr::PROFANITY_MASK);
}
if (remove_profane_words_) {
config->set_profanity_filter(nr_asr::PROFANITY_REMOVE);
}
config->set_audio_channel_count(wav->channels);
config->set_enable_word_time_offsets(word_time_offsets_);
config->set_enable_automatic_punctuation(automatic_punctuation_);
Expand Down Expand Up @@ -338,6 +347,7 @@ class RecognizeClient {
std::string language_code_;
int32_t max_alternatives_;
bool profanity_filter_;
bool remove_profane_words_;
int32_t channels_;
bool word_time_offsets_;
bool automatic_punctuation_;
Expand Down Expand Up @@ -377,6 +387,7 @@ main(int argc, char** argv)
str_usage << " --automatic_punctuation=<true|false>" << std::endl;
str_usage << " --max_alternatives=<integer>" << std::endl;
str_usage << " --profanity_filter=<true|false>" << std::endl;
str_usage << " --remove_profane_words=<true|false>" << std::endl;
str_usage << " --word_time_offsets=<true|false>" << std::endl;
str_usage << " --riva_uri=<server_name:port> " << std::endl;
str_usage << " --num_iterations=<integer> " << std::endl;
Expand Down Expand Up @@ -433,7 +444,7 @@ main(int argc, char** argv)

RecognizeClient recognize_client(
grpc_channel, FLAGS_language_code, FLAGS_max_alternatives, FLAGS_profanity_filter,
FLAGS_word_time_offsets, FLAGS_automatic_punctuation,
FLAGS_remove_profane_words, FLAGS_word_time_offsets, FLAGS_automatic_punctuation,
/* separate_recognition_per_channel*/ false, FLAGS_print_transcripts, FLAGS_output_filename,
FLAGS_model_name, FLAGS_output_ctm, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file,
(float)FLAGS_boosted_words_score, FLAGS_speaker_diarization);
Expand Down
7 changes: 6 additions & 1 deletion riva/clients/asr/riva_streaming_asr_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ DEFINE_int32(
DEFINE_bool(
profanity_filter, false,
"Flag that controls if generated transcripts should be filtered for the profane words");
// Command-line switch forwarded to StreamingRecognizeClient. When true, the
// streaming config's profanity filter is set to nr_asr::PROFANITY_REMOVE; this
// is applied after the --profanity_filter check and therefore overrides the
// nr_asr::PROFANITY_MASK setting when both flags are enabled.
DEFINE_bool(
remove_profane_words, false,
"Flag that controls if the profane words should be removed from the transcript");
DEFINE_bool(automatic_punctuation, true, "Flag that controls if transcript should be punctuated");
DEFINE_bool(word_time_offsets, true, "Flag that controls if word time stamps are requested");
DEFINE_bool(
Expand Down Expand Up @@ -103,6 +106,7 @@ main(int argc, char** argv)
str_usage << " --automatic_punctuation=<true|false>" << std::endl;
str_usage << " --max_alternatives=<integer>" << std::endl;
str_usage << " --profanity_filter=<true|false>" << std::endl;
str_usage << " --remove_profane_words=<true|false>" << std::endl;
str_usage << " --word_time_offsets=<true|false>" << std::endl;
str_usage << " --riva_uri=<server_name:port> " << std::endl;
str_usage << " --chunk_duration_ms=<integer> " << std::endl;
Expand Down Expand Up @@ -161,7 +165,8 @@ main(int argc, char** argv)

StreamingRecognizeClient recognize_client(
grpc_channel, FLAGS_num_parallel_requests, FLAGS_language_code, FLAGS_max_alternatives,
FLAGS_profanity_filter, FLAGS_word_time_offsets, FLAGS_automatic_punctuation,
FLAGS_profanity_filter, FLAGS_remove_profane_words, FLAGS_word_time_offsets,
FLAGS_automatic_punctuation,
/* separate_recognition_per_channel*/ false, FLAGS_print_transcripts, FLAGS_chunk_duration_ms,
FLAGS_interim_results, FLAGS_output_filename, FLAGS_model_name, FLAGS_simulate_realtime,
FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score);
Expand Down
29 changes: 21 additions & 8 deletions riva/clients/asr/streaming_recognize_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@ MicrophoneThreadMain(
StreamingRecognizeClient::StreamingRecognizeClient(
std::shared_ptr<grpc::Channel> channel, int32_t num_parallel_requests,
const std::string& language_code, int32_t max_alternatives, bool profanity_filter,
bool word_time_offsets, bool automatic_punctuation, bool separate_recognition_per_channel,
bool print_transcripts, int32_t chunk_duration_ms, bool interim_results,
std::string output_filename, std::string model_name, bool simulate_realtime,
bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score)
bool remove_profane_words, bool word_time_offsets, bool automatic_punctuation,
bool separate_recognition_per_channel, bool print_transcripts, int32_t chunk_duration_ms,
bool interim_results, std::string output_filename, std::string model_name,
bool simulate_realtime, bool verbatim_transcripts, const std::string& boosted_phrases_file,
float boosted_phrases_score)
: print_latency_stats_(true), stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)),
language_code_(language_code), max_alternatives_(max_alternatives),
profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets),
automatic_punctuation_(automatic_punctuation),
profanity_filter_(profanity_filter), remove_profane_words_(remove_profane_words),
word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation),
separate_recognition_per_channel_(separate_recognition_per_channel),
print_transcripts_(print_transcripts), chunk_duration_ms_(chunk_duration_ms),
interim_results_(interim_results), total_audio_processed_(0.), num_streams_started_(0),
Expand Down Expand Up @@ -123,7 +124,13 @@ StreamingRecognizeClient::GenerateRequests(std::shared_ptr<ClientCall> call)
config->set_language_code(language_code_);
config->set_encoding(call->stream->wav->encoding);
config->set_max_alternatives(max_alternatives_);
config->set_profanity_filter(profanity_filter_);
config->set_profanity_filter(nr_asr::PROFANITY_OFF);
if (profanity_filter_) {
config->set_profanity_filter(nr_asr::PROFANITY_MASK);
}
if (remove_profane_words_) {
config->set_profanity_filter(nr_asr::PROFANITY_REMOVE);
}
config->set_audio_channel_count(call->stream->wav->channels);
config->set_enable_word_time_offsets(word_time_offsets_);
config->set_enable_automatic_punctuation(automatic_punctuation_);
Expand Down Expand Up @@ -376,7 +383,13 @@ StreamingRecognizeClient::DoStreamingFromMicrophone(
config->set_language_code(language_code_);
config->set_encoding(encoding);
config->set_max_alternatives(max_alternatives_);
config->set_profanity_filter(profanity_filter_);
config->set_profanity_filter(nr_asr::PROFANITY_OFF);
if (profanity_filter_) {
config->set_profanity_filter(nr_asr::PROFANITY_MASK);
}
if (remove_profane_words_) {
config->set_profanity_filter(nr_asr::PROFANITY_REMOVE);
}
config->set_audio_channel_count(channels);
config->set_enable_word_time_offsets(word_time_offsets_);
config->set_enable_automatic_punctuation(automatic_punctuation_);
Expand Down
9 changes: 5 additions & 4 deletions riva/clients/asr/streaming_recognize_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ class StreamingRecognizeClient {
StreamingRecognizeClient(
std::shared_ptr<grpc::Channel> channel, int32_t num_parallel_requests,
const std::string& language_code, int32_t max_alternatives, bool profanity_filter,
bool word_time_offsets, bool automatic_punctuation, bool separate_recognition_per_channel,
bool print_transcripts, int32_t chunk_duration_ms, bool interim_results,
std::string output_filename, std::string model_name, bool simulate_realtime,
bool verbatim_transcripts, const std::string& boosted_phrases_file,
bool remove_profane_words, bool word_time_offsets, bool automatic_punctuation,
bool separate_recognition_per_channel, bool print_transcripts, int32_t chunk_duration_ms,
bool interim_results, std::string output_filename, std::string model_name,
bool simulate_realtime, bool verbatim_transcripts, const std::string& boosted_phrases_file,
float boosted_phrases_score);

~StreamingRecognizeClient();
Expand Down Expand Up @@ -87,6 +87,7 @@ class StreamingRecognizeClient {
std::string language_code_;
int32_t max_alternatives_;
bool profanity_filter_;
bool remove_profane_words_;
int32_t channels_;
bool word_time_offsets_;
bool automatic_punctuation_;
Expand Down
4 changes: 2 additions & 2 deletions riva/clients/asr/streaming_recognize_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ TEST(StreamingRecognizeClient, num_responses_requests)
auto current_time = std::chrono::steady_clock::now();

StreamingRecognizeClient recognize_client(
grpc_channel, 1, "en-US", 1, false, false, false, false, false, 800, false, "dummy.txt",
"dummy", true, true, "", 10.);
grpc_channel, 1, "en-US", 1, false, false, false, false, false, false, 800, false,
"dummy.txt", "dummy", true, true, "", 10.);

std::shared_ptr<ClientCall> call = std::make_shared<ClientCall>(1, true);
uint32_t num_sends = 10;
Expand Down
6 changes: 5 additions & 1 deletion riva/clients/nmt/riva_nmt_streaming_s2s_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ DEFINE_string(
DEFINE_bool(
profanity_filter, false,
"Flag that controls if generated transcripts should be filtered for the profane words");
// Command-line switch passed to StreamingS2SClient alongside --profanity_filter.
// Help text corrected for grammar and aligned with the wording used for the
// same flag in the streaming ASR client.
DEFINE_bool(
remove_profane_words, false,
"Flag that controls if the profane words should be removed from the transcript");
DEFINE_bool(automatic_punctuation, true, "Flag that controls if transcript should be punctuated");
DEFINE_bool(
simulate_realtime, false, "Flag that controls if audio files should be sent in realtime");
Expand Down Expand Up @@ -97,6 +99,7 @@ main(int argc, char** argv)
str_usage << " --audio_device=<device_id (such as hw:5,0)> " << std::endl;
str_usage << " --automatic_punctuation=<true|false>" << std::endl;
str_usage << " --profanity_filter=<true|false>" << std::endl;
str_usage << " --remove_profane_words=<true|false>" << std::endl;
str_usage << " --riva_uri=<server_name:port> " << std::endl;
str_usage << " --chunk_duration_ms=<integer> " << std::endl;
str_usage << " --simulate_realtime=<true|false> " << std::endl;
Expand Down Expand Up @@ -159,7 +162,8 @@ main(int argc, char** argv)

StreamingS2SClient recognize_client(
grpc_channel, FLAGS_num_parallel_requests, FLAGS_source_language_code,
FLAGS_target_language_code, FLAGS_profanity_filter, FLAGS_automatic_punctuation,
FLAGS_target_language_code, FLAGS_profanity_filter, FLAGS_remove_profane_words,
FLAGS_automatic_punctuation,
/* separate_recognition_per_channel*/ false, FLAGS_chunk_duration_ms, FLAGS_simulate_realtime,
FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score,
FLAGS_tts_encoding, FLAGS_tts_audio_file, FLAGS_tts_sample_rate, FLAGS_tts_voice_name);
Expand Down
6 changes: 5 additions & 1 deletion riva/clients/nmt/riva_nmt_streaming_s2t_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ DEFINE_string(
DEFINE_bool(
profanity_filter, false,
"Flag that controls if generated transcripts should be filtered for the profane words");
// Command-line switch passed to StreamingS2TClient alongside --profanity_filter.
// Help text corrected for grammar and aligned with the wording used for the
// same flag in the streaming ASR client.
DEFINE_bool(
remove_profane_words, false,
"Flag that controls if the profane words should be removed from the transcript");
DEFINE_bool(automatic_punctuation, true, "Flag that controls if transcript should be punctuated");
DEFINE_bool(
simulate_realtime, false, "Flag that controls if audio files should be sent in realtime");
Expand Down Expand Up @@ -93,6 +95,7 @@ main(int argc, char** argv)
str_usage << " --audio_device=<device_id (such as hw:5,0)> " << std::endl;
str_usage << " --automatic_punctuation=<true|false>" << std::endl;
str_usage << " --profanity_filter=<true|false>" << std::endl;
str_usage << " --remove_profane_words=<true|false>" << std::endl;
str_usage << " --riva_uri=<server_name:port> " << std::endl;
str_usage << " --chunk_duration_ms=<integer> " << std::endl;
str_usage << " --simulate_realtime=<true|false> " << std::endl;
Expand Down Expand Up @@ -146,7 +149,8 @@ main(int argc, char** argv)

StreamingS2TClient recognize_client(
grpc_channel, FLAGS_num_parallel_requests, FLAGS_source_language_code,
FLAGS_target_language_code, FLAGS_profanity_filter, FLAGS_automatic_punctuation,
FLAGS_target_language_code, FLAGS_profanity_filter, FLAGS_remove_profane_words,
FLAGS_automatic_punctuation,
/* separate_recognition_per_channel*/ false, FLAGS_chunk_duration_ms, FLAGS_simulate_realtime,
FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score,
FLAGS_nmt_text_file);
Expand Down
Loading