Skip to content

Commit 6db03dc

Browse files
support word boosting from file
1 parent 0532a1f commit 6db03dc

File tree

4 files changed

+67
-38
lines changed

4 files changed

+67
-38
lines changed

riva/client/argparse_utils.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: MIT
33

44
import argparse
5+
from pathlib import Path
56

67

78
def add_asr_config_argparse_parameters(
@@ -35,16 +36,18 @@ def add_asr_config_argparse_parameters(
3536
help="Flag that controls if transcript should be automatically punctuated",
3637
)
3738
parser.add_argument(
38-
"--no-verbatim-transcripts",
39-
default=False,
40-
action='store_true',
41-
help="If specified, text inverse normalization will be applied",
39+
"--verbatim-transcripts",
40+
default=True,
41+
action='store_false',
42+
help="True returns text exactly as it was said. False applies Inverse text normalization",
4243
)
4344
parser.add_argument("--language-code", default="en-US", help="Language code of the model to be used.")
4445
parser.add_argument("--model-name", default="", help="Name of the model to be used to be used.")
45-
parser.add_argument("--boosted-lm-words", action='append', help="Words to boost when decoding.")
4646
parser.add_argument(
47-
"--boosted-lm-score", type=float, default=4.0, help="Value by which to boost words when decoding."
47+
"--boosted-words-file", default=None, type=Path, help="File with a list of words to boost. One line per word."
48+
)
49+
parser.add_argument(
50+
"--boosted-words-score", type=float, default=4.0, help="Score by which to boost the boosted words."
4851
)
4952
parser.add_argument(
5053
"--speaker-diarization",

riva/client/asr.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,19 @@ def __next__(self) -> bytes:
9797

9898
def add_word_boosting_to_config(
9999
config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig],
100-
boosted_lm_words: Optional[List[str]],
101-
boosted_lm_score: float,
100+
boosted_words_file: Union[str, os.PathLike],
101+
boosted_words_score: float,
102102
) -> None:
103103
inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
104-
if boosted_lm_words is not None:
104+
boosted_words = []
105+
if boosted_words_file:
106+
with open(boosted_words_file) as f:
107+
boosted_words = f.read().splitlines()
108+
109+
if boosted_words is not None:
105110
speech_context = rasr.SpeechContext()
106-
speech_context.phrases.extend(boosted_lm_words)
107-
speech_context.boost = boosted_lm_score
111+
speech_context.phrases.extend(boosted_words)
112+
speech_context.boost = boosted_words_score
108113
inner_config.speech_contexts.append(speech_context)
109114

110115

scripts/asr/riva_streaming_asr_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def parse_args() -> argparse.Namespace:
2626
parser.add_argument("--num-parallel-requests", default=1, type=int, help="Number of client threads.")
2727
parser.add_argument("--num-iterations", default=1, type=int, help="Number of iterations over the file.")
2828
parser.add_argument(
29-
"--input-file", required=True, type=str, help="Name of the WAV file with LINEAR_PCM encoding to transcribe."
29+
"--input-file", required=True, type=Path, help="Name of the WAV file with LINEAR_PCM encoding to transcribe."
3030
)
3131
parser.add_argument(
3232
"--simulate-realtime",

scripts/asr/transcribe_file.py

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# SPDX-License-Identifier: MIT
33

44
import argparse
5+
import json
6+
from pathlib import Path
57

68
import riva.client
79
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
@@ -15,7 +17,12 @@ def parse_args() -> argparse.Namespace:
1517
"`--play-audio` or `--output-device`.",
1618
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
1719
)
18-
parser.add_argument("--input-file", help="A path to a local file to stream.")
20+
parser.add_argument(
21+
"--input-file",
22+
required=True,
23+
type=Path,
24+
help="A path to a local file to stream or a JSONL file containing list of files. JSONL file should contain JSON entry on each line, for example: {'audio_filepath': 'audio.wav'} ",
25+
)
1926
parser.add_argument("--list-devices", action="store_true", help="List output devices indices")
2027
parser.add_argument(
2128
"--interim-results", default=False, action='store_true', help="Print intermediate transcripts",
@@ -63,6 +70,17 @@ def main() -> None:
6370
if args.list_devices:
6471
riva.client.audio_io.list_output_devices()
6572
return
73+
input_files = []
74+
if args.input_file.suffix == ".json":
75+
with open(args.input_file) as f:
76+
lines = f.read().splitlines()
77+
for line in lines:
78+
data = json.loads(line)
79+
if "audio_filepath" in data:
80+
input_files.append(data["audio_filepath"])
81+
else:
82+
input_files = [args.input_file]
83+
6684
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
6785
asr_service = riva.client.ASRService(auth)
6886
config = riva.client.StreamingRecognitionConfig(
@@ -71,37 +89,40 @@ def main() -> None:
7189
max_alternatives=args.max_alternatives,
7290
profanity_filter=args.profanity_filter,
7391
enable_automatic_punctuation=args.automatic_punctuation,
74-
verbatim_transcripts=not args.no_verbatim_transcripts,
92+
verbatim_transcripts=args.verbatim_transcripts,
7593
enable_word_time_offsets=args.word_time_offsets,
7694
model=args.model_name,
7795
),
7896
interim_results=args.interim_results,
7997
)
80-
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
98+
riva.client.add_word_boosting_to_config(config, args.boosted_words_file, args.boosted_words_score)
8199
sound_callback = None
82-
try:
83-
if args.play_audio or args.output_device is not None:
84-
wp = riva.client.get_wav_file_parameters(args.input_file)
85-
sound_callback = riva.client.audio_io.SoundCallBack(
86-
args.output_device, wp['sampwidth'], wp['nchannels'], wp['framerate'],
87-
)
88-
delay_callback = sound_callback
89-
else:
90-
delay_callback = riva.client.sleep_audio_length if args.simulate_realtime else None
91-
with riva.client.AudioChunkFileIterator(
92-
args.input_file, args.chunk_duration_ms, delay_callback,
93-
) as audio_chunk_iterator:
94-
riva.client.print_streaming(
95-
responses=asr_service.streaming_response_generator(
96-
audio_chunks=audio_chunk_iterator, streaming_config=config,
97-
),
98-
input_file=args.input_file,
99-
show_intermediate=args.interim_results,
100-
additional_info="confidence" if args.print_confidence else "no",
101-
)
102-
finally:
103-
if sound_callback is not None and sound_callback.opened:
104-
sound_callback.close()
100+
101+
for file in input_files:
102+
try:
103+
if args.play_audio or args.output_device is not None:
104+
wp = riva.client.get_wav_file_parameters(file)
105+
sound_callback = riva.client.audio_io.SoundCallBack(
106+
args.output_device, wp['sampwidth'], wp['nchannels'], wp['framerate'],
107+
)
108+
delay_callback = sound_callback
109+
else:
110+
delay_callback = riva.client.sleep_audio_length if args.simulate_realtime else None
111+
112+
with riva.client.AudioChunkFileIterator(
113+
file, args.chunk_duration_ms, delay_callback,
114+
) as audio_chunk_iterator:
115+
riva.client.print_streaming(
116+
responses=asr_service.streaming_response_generator(
117+
audio_chunks=audio_chunk_iterator, streaming_config=config,
118+
),
119+
input_file=file,
120+
show_intermediate=args.interim_results,
121+
additional_info="confidence" if args.print_confidence else "no",
122+
)
123+
finally:
124+
if sound_callback is not None and sound_callback.opened:
125+
sound_callback.close()
105126

106127

107128
if __name__ == "__main__":

0 commit comments

Comments
 (0)