Skip to content

Update client to use profanity settings #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: release/2.13.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions riva/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
add_word_boosting_to_config,
add_speaker_diarization_to_config,
get_wav_file_parameters,
get_profanity_setting,
print_offline,
print_streaming,
sleep_audio_length,
Expand Down
9 changes: 8 additions & 1 deletion riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def add_asr_config_argparse_parameters(
parser: argparse.ArgumentParser, max_alternatives: bool = False, profanity_filter: bool = False, word_time_offsets: bool = False
parser: argparse.ArgumentParser, max_alternatives: bool = False, profanity_filter: bool = False, remove_profane_words: bool = False, word_time_offsets: bool = False
) -> argparse.ArgumentParser:
if word_time_offsets:
parser.add_argument(
Expand All @@ -25,6 +25,13 @@ def add_asr_config_argparse_parameters(
action='store_true',
help="Flag that controls the profanity filtering in the generated transcripts",
)
if remove_profane_words:
parser.add_argument(
"--remove-profane-words",
default=False,
action='store_true',
help="Flag that removes profane words in the generated transcripts",
)
parser.add_argument(
"--automatic-punctuation",
default=False,
Expand Down
7 changes: 7 additions & 0 deletions riva/client/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ def get_wav_file_parameters(input_file: Union[str, os.PathLike]) -> Dict[str, Un
def sleep_audio_length(audio_chunk: bytes, time_to_sleep: float) -> None:
    """Block the calling thread for ``time_to_sleep`` seconds.

    Args:
        audio_chunk: Accepted but not used by this function.
        time_to_sleep: Number of seconds passed to :func:`time.sleep`.
    """
    time.sleep(time_to_sleep)

def get_profanity_setting(profanity_filter: bool, remove_profane_words: bool) -> rasr.ProfanitySettings:
    """Translate the two profanity flags into a ``rasr.ProfanitySettings`` value.

    Args:
        profanity_filter: When truthy, select ``PROFANITY_MASK``.
        remove_profane_words: When truthy, select ``PROFANITY_REMOVE``;
            this wins over ``profanity_filter`` if both are set.

    Returns:
        ``PROFANITY_REMOVE``, ``PROFANITY_MASK`` or ``PROFANITY_OFF``
        (the latter when neither flag is set).
    """
    # Checked first because removal overrides masking when both flags are given.
    if remove_profane_words:
        return rasr.ProfanitySettings.PROFANITY_REMOVE
    if profanity_filter:
        return rasr.ProfanitySettings.PROFANITY_MASK
    return rasr.ProfanitySettings.PROFANITY_OFF

class AudioChunkFileIterator:
def __init__(
Expand Down
6 changes: 3 additions & 3 deletions scripts/asr/riva_streaming_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Union

import riva.client
from riva.client.asr import get_wav_file_parameters
from riva.client.asr import get_wav_file_parameters , get_profanity_setting
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters


Expand Down Expand Up @@ -38,7 +38,7 @@ def parse_args() -> argparse.Namespace:
"--file-streaming-chunk", type=int, default=1600, help="Number of frames in one chunk sent to server."
)
parser = add_connection_argparse_parameters(parser)
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, remove_profane_words=True, word_time_offsets=True)
args = parser.parse_args()
if args.max_alternatives < 1:
parser.error("`--max-alternatives` must be greater than or equal to 1")
Expand All @@ -56,7 +56,7 @@ def streaming_transcription_worker(
config=riva.client.RecognitionConfig(
language_code=args.language_code,
max_alternatives=args.max_alternatives,
profanity_filter=args.profanity_filter,
profanity_filter=get_profanity_setting(args.profanity_filter,args.remove_profane_words),
enable_automatic_punctuation=args.automatic_punctuation,
verbatim_transcripts=not args.no_verbatim_transcripts,
enable_word_time_offsets=args.word_time_offsets,
Expand Down
4 changes: 2 additions & 2 deletions scripts/asr/transcribe_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def parse_args() -> argparse.Namespace:
"--print-confidence", action="store_true", help="Whether to print stability and confidence of transcript."
)
parser = add_connection_argparse_parameters(parser)
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True,remove_profane_words=True, word_time_offsets=True)
args = parser.parse_args()
if not args.list_devices and args.input_file is None:
parser.error(
Expand All @@ -72,7 +72,7 @@ def main() -> None:
config=riva.client.RecognitionConfig(
language_code=args.language_code,
max_alternatives=1,
profanity_filter=args.profanity_filter,
profanity_filter=riva.client.get_profanity_setting(args.profanity_filter,args.remove_profane_words),
enable_automatic_punctuation=args.automatic_punctuation,
verbatim_transcripts=not args.no_verbatim_transcripts,
),
Expand Down
4 changes: 2 additions & 2 deletions scripts/asr/transcribe_file_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def parse_args() -> argparse.Namespace:
)
parser.add_argument("--input-file", required=True, type=Path, help="A path to a local file to transcribe.")
parser = add_connection_argparse_parameters(parser)
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, remove_profane_words=True, word_time_offsets=True)
args = parser.parse_args()
args.input_file = args.input_file.expanduser()
return args
Expand All @@ -31,7 +31,7 @@ def main() -> None:
config = riva.client.RecognitionConfig(
language_code=args.language_code,
max_alternatives=args.max_alternatives,
profanity_filter=args.profanity_filter,
profanity_filter=riva.client.get_profanity_setting(args.profanity_filter,args.remove_profane_words),
enable_automatic_punctuation=args.automatic_punctuation,
verbatim_transcripts=not args.no_verbatim_transcripts,
enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,
Expand Down
4 changes: 2 additions & 2 deletions scripts/asr/transcribe_mic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def parse_args() -> argparse.Namespace:
)
parser.add_argument("--input-device", type=int, default=default_device_index, help="An input audio device to use.")
parser.add_argument("--list-devices", action="store_true", help="List input audio device indices.")
parser = add_asr_config_argparse_parameters(parser, profanity_filter=True)
parser = add_asr_config_argparse_parameters(parser, profanity_filter=True, remove_profane_words=True)
parser = add_connection_argparse_parameters(parser)
parser.add_argument(
"--sample-rate-hz",
Expand Down Expand Up @@ -48,7 +48,7 @@ def main() -> None:
encoding=riva.client.AudioEncoding.LINEAR_PCM,
language_code=args.language_code,
max_alternatives=1,
profanity_filter=args.profanity_filter,
profanity_filter=riva.client.get_profanity_setting(args.profanity_filter,args.remove_profane_words),
enable_automatic_punctuation=args.automatic_punctuation,
verbatim_transcripts=not args.no_verbatim_transcripts,
sample_rate_hertz=args.sample_rate_hz,
Expand Down