def get_profanity_setting(profanity_filter: bool, remove_profane_words: bool) -> rasr.ProfanitySettings:
    """Translate the two profanity command-line flags into one setting.

    Args:
        profanity_filter: when true (and ``remove_profane_words`` is false),
            ``PROFANITY_MASK`` is selected.
        remove_profane_words: when true, ``PROFANITY_REMOVE`` is selected
            regardless of ``profanity_filter``.

    Returns:
        ``rasr.ProfanitySettings.PROFANITY_REMOVE``, ``PROFANITY_MASK`` or
        ``PROFANITY_OFF``, checked in that order of precedence.
    """
    # NOTE(review): precedence assumes the two original `if` statements were
    # siblings (the later assignment overrides the earlier) - confirm.
    if remove_profane_words:
        return rasr.ProfanitySettings.PROFANITY_REMOVE
    if profanity_filter:
        return rasr.ProfanitySettings.PROFANITY_MASK
    return rasr.ProfanitySettings.PROFANITY_OFF
"--print-confidence", action="store_true", help="Whether to print stability and confidence of transcript." ) parser = add_connection_argparse_parameters(parser) - parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True) + parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True,remove_profane_words=True, word_time_offsets=True) args = parser.parse_args() if not args.list_devices and args.input_file is None: parser.error( @@ -72,7 +72,7 @@ def main() -> None: config=riva.client.RecognitionConfig( language_code=args.language_code, max_alternatives=1, - profanity_filter=args.profanity_filter, + profanity_filter=riva.client.get_profanity_setting(args.profanity_filter,args.remove_profane_words), enable_automatic_punctuation=args.automatic_punctuation, verbatim_transcripts=not args.no_verbatim_transcripts, ), diff --git a/scripts/asr/transcribe_file_offline.py b/scripts/asr/transcribe_file_offline.py index 82c52173..85a3e07b 100644 --- a/scripts/asr/transcribe_file_offline.py +++ b/scripts/asr/transcribe_file_offline.py @@ -18,7 +18,7 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--input-file", required=True, type=Path, help="A path to a local file to transcribe.") parser = add_connection_argparse_parameters(parser) - parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True) + parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, remove_profane_words=True, word_time_offsets=True) args = parser.parse_args() args.input_file = args.input_file.expanduser() return args @@ -31,7 +31,7 @@ def main() -> None: config = riva.client.RecognitionConfig( language_code=args.language_code, max_alternatives=args.max_alternatives, - profanity_filter=args.profanity_filter, + 
profanity_filter=riva.client.get_profanity_setting(args.profanity_filter,args.remove_profane_words), enable_automatic_punctuation=args.automatic_punctuation, verbatim_transcripts=not args.no_verbatim_transcripts, enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization, diff --git a/scripts/asr/transcribe_mic.py b/scripts/asr/transcribe_mic.py index 1924c05b..4d40949a 100644 --- a/scripts/asr/transcribe_mic.py +++ b/scripts/asr/transcribe_mic.py @@ -18,7 +18,7 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--input-device", type=int, default=default_device_index, help="An input audio device to use.") parser.add_argument("--list-devices", action="store_true", help="List input audio device indices.") - parser = add_asr_config_argparse_parameters(parser, profanity_filter=True) + parser = add_asr_config_argparse_parameters(parser, profanity_filter=True, remove_profane_words=True) parser = add_connection_argparse_parameters(parser) parser.add_argument( "--sample-rate-hz", @@ -48,7 +48,7 @@ def main() -> None: encoding=riva.client.AudioEncoding.LINEAR_PCM, language_code=args.language_code, max_alternatives=1, - profanity_filter=args.profanity_filter, + profanity_filter=riva.client.get_profanity_setting(args.profanity_filter,args.remove_profane_words), enable_automatic_punctuation=args.automatic_punctuation, verbatim_transcripts=not args.no_verbatim_transcripts, sample_rate_hertz=args.sample_rate_hz,