Skip to content

Commit ebc2c05

Browse files
authored
Add endpointing config parameters to ASR clients (#80)
* asr: add eou param to py clients * feat(asr): rename params * asr: rename variable * update default values and checks * asr: add validation check * asr: update gitmodule * asr: update git submodule * asr: update protos with main branch * asr: update .gitmodules
1 parent 9599157 commit ebc2c05

File tree

9 files changed

+91
-4
lines changed

9 files changed

+91
-4
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ its purpose and parameters.
8484

8585
#### ASR
8686

87-
You may find a detailed documentation [here](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/apis/development-cpp.html).
87+
You may find detailed documentation [here](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/apis/cli.html).
8888

8989
For transcribing in streaming mode you may use `scripts/asr/transcribe_file.py`.
9090
```bash

common

riva/client/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
print_offline,
1212
print_streaming,
1313
sleep_audio_length,
14+
add_endpoint_parameters_to_config,
1415
)
1516
from riva.client.auth import Auth
1617
from riva.client.nlp import (
@@ -33,7 +34,7 @@
3334
__shortversion__,
3435
__version__,
3536
)
36-
from riva.client.proto.riva_asr_pb2 import RecognitionConfig, StreamingRecognitionConfig
37+
from riva.client.proto.riva_asr_pb2 import RecognitionConfig, StreamingRecognitionConfig, EndpointingConfig
3738
from riva.client.proto.riva_audio_pb2 import AudioEncoding
3839
from riva.client.proto.riva_nlp_pb2 import AnalyzeIntentOptions
3940
from riva.client.proto.riva_nmt_pb2 import StreamingTranslateSpeechToSpeechConfig, TranslationConfig, SynthesizeSpeechConfig, StreamingTranslateSpeechToTextConfig

riva/client/argparse_utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,36 @@ def add_asr_config_argparse_parameters(
4848
action='store_true',
4949
help="Flag that controls if speaker diarization should be performed",
5050
)
51+
parser.add_argument(
52+
"--start-history",
53+
default=-1,
54+
type=int,
55+
help="Value to detect and initiate start of speech utterance",
56+
)
57+
parser.add_argument(
58+
"--start-threshold",
59+
default=-1.0,
60+
type=float,
61+
help="Threshold value for detecting the start of speech utterance",
62+
)
63+
parser.add_argument(
64+
"--stop-history",
65+
default=-1,
66+
type=int,
67+
help="Value to reset the endpoint detection history",
68+
)
69+
parser.add_argument(
70+
"--stop-history-eou",
71+
default=-1,
72+
type=int,
73+
help="Value to determine the response history for endpoint detection",
74+
)
75+
parser.add_argument(
76+
"--stop-threshold",
77+
default=-1.0,
78+
type=float,
79+
help="Threshold value for detecting the end of speech utterance",
80+
)
5181
return parser
5282

5383

riva/client/asr.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,31 @@ def add_speaker_diarization_to_config(
123123
diarization_config = rasr.SpeakerDiarizationConfig(enable_speaker_diarization=True)
124124
inner_config.diarization_config.CopyFrom(diarization_config)
125125

126+
def add_endpoint_parameters_to_config(
    config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig],
    start_history: int,
    start_threshold: float,
    stop_history: int,
    stop_history_eou: int,
    stop_threshold: float,
) -> None:
    """Attach endpointing settings to an ASR recognition config in place.

    Only arguments with a value greater than zero are written to the
    ``EndpointingConfig``; anything else is treated as "not provided" and the
    corresponding field is left unset. If none of the arguments is positive,
    ``config`` is not modified at all.

    Args:
        config: Either a ``RecognitionConfig``, or a wrapper exposing the
            recognition config through its ``.config`` attribute (the
            streaming config). The endpointing settings are copied into the
            recognition config's ``endpointing_config`` field.
        start_history: Value for ``EndpointingConfig.start_history``.
        start_threshold: Value for ``EndpointingConfig.start_threshold``.
        stop_history: Value for ``EndpointingConfig.stop_history``.
        stop_history_eou: Value for ``EndpointingConfig.stop_history_eou``.
        stop_threshold: Value for ``EndpointingConfig.stop_threshold``.
    """
    # Keep only the parameters the caller actually supplied (value > 0).
    provided = {
        field_name: value
        for field_name, value in (
            ("start_history", start_history),
            ("start_threshold", start_threshold),
            ("stop_history", stop_history),
            ("stop_history_eou", stop_history_eou),
            ("stop_threshold", stop_threshold),
        )
        if value > 0
    }
    if not provided:
        return

    # A non-RecognitionConfig argument carries the recognition config in `.config`.
    inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
    inner_config.endpointing_config.CopyFrom(rasr.EndpointingConfig(**provided))
150+
126151

127152
PRINT_STREAMING_ADDITIONAL_INFO_MODES = ['no', 'time', 'confidence']
128153

scripts/asr/riva_streaming_asr_client.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ def streaming_transcription_worker(
6363
),
6464
interim_results=True,
6565
)
66+
riva.client.add_endpoint_parameters_to_config(
67+
config,
68+
args.start_history,
69+
args.start_threshold,
70+
args.stop_history,
71+
args.stop_history_eou,
72+
args.stop_threshold
73+
)
6674
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
6775
for _ in range(args.num_iterations):
6876
with riva.client.AudioChunkFileIterator(

scripts/asr/transcribe_file.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ def main() -> None:
7979
interim_results=True,
8080
)
8181
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
82+
riva.client.add_endpoint_parameters_to_config(
83+
config,
84+
args.start_history,
85+
args.start_threshold,
86+
args.stop_history,
87+
args.stop_history_eou,
88+
args.stop_threshold
89+
)
8290
sound_callback = None
8391
try:
8492
if args.play_audio or args.output_device is not None:

scripts/asr/transcribe_file_offline.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,14 @@ def main() -> None:
3838
)
3939
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
4040
riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization)
41-
41+
riva.client.add_endpoint_parameters_to_config(
42+
config,
43+
args.start_history,
44+
args.start_threshold,
45+
args.stop_history,
46+
args.stop_history_eou,
47+
args.stop_threshold
48+
)
4249
with args.input_file.open('rb') as fh:
4350
data = fh.read()
4451
try:

scripts/asr/transcribe_mic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ def main() -> None:
5757
interim_results=True,
5858
)
5959
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
60+
riva.client.add_endpoint_parameters_to_config(
61+
config,
62+
args.start_history,
63+
args.start_threshold,
64+
args.stop_history,
65+
args.stop_history_eou,
66+
args.stop_threshold
67+
)
6068
with riva.client.audio_io.MicrophoneStream(
6169
args.sample_rate_hz,
6270
args.file_streaming_chunk,

0 commit comments

Comments
 (0)