Skip to content

asr client updates #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
# SPDX-License-Identifier: MIT

import argparse
from pathlib import Path


def add_asr_config_argparse_parameters(
parser: argparse.ArgumentParser, max_alternatives: bool = False, profanity_filter: bool = False, word_time_offsets: bool = False
parser: argparse.ArgumentParser,
max_alternatives: bool = False,
profanity_filter: bool = False,
word_time_offsets: bool = False,
) -> argparse.ArgumentParser:
if word_time_offsets:
parser.add_argument(
Expand All @@ -20,27 +24,30 @@ def add_asr_config_argparse_parameters(
)
if profanity_filter:
parser.add_argument(
"--profanity-filter",
default=False,
action='store_true',
help="Flag that controls the profanity filtering in the generated transcripts",
)
"--profanity-filter",
default=False,
action='store_true',
help="Flag that controls the profanity filtering in the generated transcripts",
)
parser.add_argument(
"--automatic-punctuation",
default=False,
action='store_true',
help="Flag that controls if transcript should be automatically punctuated",
)
parser.add_argument(
"--no-verbatim-transcripts",
"--verbatim-transcripts",
default=False,
action='store_true',
help="If specified, text inverse normalization will be applied",
help="Flag to disable Inverse text normalization and return the text exactly as it was said",
)
parser.add_argument("--language-code", default="en-US", help="Language code of the model to be used.")
parser.add_argument("--boosted-lm-words", action='append', help="Words to boost when decoding.")
parser.add_argument("--model-name", default="", help="Name of the model to be used.")
parser.add_argument(
"--boosted-words-file", default=None, type=Path, help="File with a list of words to boost. One line per word."
)
parser.add_argument(
"--boosted-lm-score", type=float, default=4.0, help="Value by which to boost words when decoding."
"--boosted-words-score", type=float, default=4.0, help="Score by which to boost the boosted words."
)
parser.add_argument(
"--speaker-diarization",
Expand Down
56 changes: 37 additions & 19 deletions riva/client/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_wav_file_parameters(input_file: Union[str, os.PathLike]) -> Dict[str, Un
'duration': nframes / rate,
'nchannels': wf.getnchannels(),
'sampwidth': wf.getsampwidth(),
'data_offset': wf.getfp().size_read + wf.getfp().offset
'data_offset': wf.getfp().size_read + wf.getfp().offset,
}
except:
# Not a WAV file
Expand All @@ -46,11 +46,11 @@ class AudioChunkFileIterator:
def __init__(
self,
input_file: Union[str, os.PathLike],
chunk_n_frames: int,
chunk_duration_ms: int,
delay_callback: Optional[Callable[[bytes, float], None]] = None,
) -> None:
self.input_file: Path = Path(input_file).expanduser()
self.chunk_n_frames = chunk_n_frames
self.chunk_duration_ms = chunk_duration_ms
self.delay_callback = delay_callback
self.file_parameters = get_wav_file_parameters(self.input_file)
self.file_object: Optional[typing.BinaryIO] = open(str(self.input_file), 'rb')
Expand All @@ -75,37 +75,46 @@ def __iter__(self):

def __next__(self) -> bytes:
if self.file_parameters:
data = self.file_object.read(self.chunk_n_frames * self.file_parameters['sampwidth'] * self.file_parameters['nchannels'])
num_frames = int(self.chunk_duration_ms * self.file_parameters['framerate'] / 1000)
data = self.file_object.read(
num_frames * self.file_parameters['sampwidth'] * self.file_parameters['nchannels']
)
else:
data = self.file_object.read(self.chunk_n_frames)
# Fixed chunk size when file_parameters is not available
data = self.file_object.read(8192)
if not data:
self.close()
raise StopIteration
if self.delay_callback is not None:
offset = self.file_parameters['data_offset'] if self.first_buffer else 0
self.delay_callback(
data[offset:], (len(data) - offset) / self.file_parameters['sampwidth'] / self.file_parameters['framerate']
data[offset:],
(len(data) - offset) / self.file_parameters['sampwidth'] / self.file_parameters['framerate'],
)
self.first_buffer = False
return data


def add_word_boosting_to_config(
config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig],
boosted_lm_words: Optional[List[str]],
boosted_lm_score: float,
boosted_words_file: Union[str, os.PathLike],
boosted_words_score: float,
) -> None:
inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
if boosted_lm_words is not None:
boosted_words = []
if boosted_words_file:
with open(boosted_words_file) as f:
boosted_words = f.read().splitlines()

if boosted_words is not None:
speech_context = rasr.SpeechContext()
speech_context.phrases.extend(boosted_lm_words)
speech_context.boost = boosted_lm_score
speech_context.phrases.extend(boosted_words)
speech_context.boost = boosted_words_score
inner_config.speech_contexts.append(speech_context)


def add_audio_file_specs_to_config(
config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig],
audio_file: Union[str, os.PathLike],
config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig], audio_file: Union[str, os.PathLike],
) -> None:
inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
wav_parameters = get_wav_file_parameters(audio_file)
Expand All @@ -114,10 +123,7 @@ def add_audio_file_specs_to_config(
inner_config.audio_channel_count = wav_parameters['nchannels']


def add_speaker_diarization_to_config(
config: Union[rasr.RecognitionConfig],
diarization_enable: bool,
) -> None:
def add_speaker_diarization_to_config(config: Union[rasr.RecognitionConfig], diarization_enable: bool,) -> None:
inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
if diarization_enable:
diarization_config = rasr.SpeakerDiarizationConfig(enable_speaker_diarization=True)
Expand All @@ -129,6 +135,7 @@ def add_speaker_diarization_to_config(

def print_streaming(
responses: Iterable[rasr.StreamingRecognizeResponse],
input_file: str = None,
output_file: Optional[Union[Union[os.PathLike, str, TextIO], List[Union[os.PathLike, str, TextIO]]]] = None,
additional_info: str = 'no',
word_time_offsets: bool = False,
Expand Down Expand Up @@ -194,6 +201,10 @@ def print_streaming(
output_file[i] = Path(elem).expanduser().open(file_mode)
start_time = time.time() # used in 'time` additional_info
num_chars_printed = 0 # used in 'no' additional_info
final_transcript = "" # for printing best final transcript
if input_file:
for f in output_file:
f.write(f"File: {input_file}\n")
for response in responses:
if not response.results:
continue
Expand All @@ -204,6 +215,7 @@ def print_streaming(
transcript = result.alternatives[0].transcript
if additional_info == 'no':
if result.is_final:
final_transcript += transcript
if show_intermediate:
overwrite_chars = ' ' * (num_chars_printed - len(transcript))
for i, f in enumerate(output_file):
Expand All @@ -221,10 +233,11 @@ def print_streaming(
partial_transcript += transcript
elif additional_info == 'time':
if result.is_final:
final_transcript += transcript
for i, alternative in enumerate(result.alternatives):
for f in output_file:
f.write(
f"Time {time.time() - start_time:.2f}s: Transcript {i}: {alternative.transcript}\n"
f"Time {time.time() - start_time:.2f}s: Final Transcript {i}: Audio Processed {result.audio_processed}: {alternative.transcript}\n"
)
if word_time_offsets:
for f in output_file:
Expand All @@ -239,6 +252,7 @@ def print_streaming(
partial_transcript += transcript
else: # additional_info == 'confidence'
if result.is_final:
final_transcript += transcript
for f in output_file:
f.write(f'## {transcript}\n')
f.write(f'Confidence: {result.alternatives[0].confidence:9.4f}\n')
Expand All @@ -255,10 +269,13 @@ def print_streaming(
elif additional_info == 'time':
for f in output_file:
if partial_transcript:
f.write(f">>>Time {time.time():.2f}s: {partial_transcript}\n")
f.write(f">>>Time {time.time() - start_time:.2f}s: {partial_transcript}\n")
else:
for f in output_file:
f.write('----\n')
for f in output_file:
f.write(f"Final transcripts:\n")
f.write(f"0 : {final_transcript}\n")
finally:
for fo, elem in zip(file_opened, output_file):
if fo:
Expand All @@ -284,6 +301,7 @@ def streaming_request_generator(

class ASRService:
"""Provides streaming and offline recognition services. Calls gRPC stubs with authentication metadata."""

def __init__(self, auth: Auth) -> None:
"""
Initializes an instance of the class.
Expand Down
10 changes: 4 additions & 6 deletions riva/client/audio_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
# SPDX-License-Identifier: MIT

import queue
from typing import Dict, Union, Optional
from typing import Dict, Optional, Union

import pyaudio


class MicrophoneStream:
"""Opens a recording stream as responses yielding the audio chunks."""

def __init__(self, rate: int, chunk: int, device: int = None) -> None:
def __init__(self, rate: int, chunk_duration_ms: int, device: int = None) -> None:
self._rate = rate
self._chunk = chunk
self._chunk = int(chunk_duration_ms * rate / 1000)
self._device = device

# Create a thread-safe buffer of audio data
Expand Down Expand Up @@ -115,9 +115,7 @@ def list_input_devices() -> None:


class SoundCallBack:
def __init__(
self, output_device_index: Optional[int], sampwidth: int, nchannels: int, framerate: int,
) -> None:
def __init__(self, output_device_index: Optional[int], sampwidth: int, nchannels: int, framerate: int,) -> None:
self.pa = pyaudio.PyAudio()
self.stream = self.pa.open(
output_device_index=output_device_index,
Expand Down
34 changes: 19 additions & 15 deletions scripts/asr/riva_streaming_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from typing import Union

import riva.client
from riva.client.asr import get_wav_file_parameters
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
from riva.client.asr import get_wav_file_parameters


def parse_args() -> argparse.Namespace:
Expand All @@ -23,22 +23,25 @@ def parse_args() -> argparse.Namespace:
"which names follow a format `output_<thread_num>.txt`.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--num-clients", default=1, type=int, help="Number of client threads.")
parser.add_argument("--num-parallel-requests", default=1, type=int, help="Number of client threads.")
parser.add_argument("--num-iterations", default=1, type=int, help="Number of iterations over the file.")
parser.add_argument(
"--input-file", required=True, type=str, help="Name of the WAV file with LINEAR_PCM encoding to transcribe."
"--input-file", required=True, type=Path, help="Name of the WAV file with LINEAR_PCM encoding to transcribe."
)
parser.add_argument(
"--simulate-realtime",
action='store_true',
help="Option to simulate realtime transcription. Audio fragments are sent to a server at a pace that mimics "
"normal speech.",
)
parser.add_argument("--chunk-duration-ms", type=int, default=100, help="Chunk duration in milliseconds.")
parser.add_argument(
"--file-streaming-chunk", type=int, default=1600, help="Number of frames in one chunk sent to server."
"--interim-results", default=False, action='store_true', help="Print intermediate transcripts",
)
parser = add_connection_argparse_parameters(parser)
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
parser = add_asr_config_argparse_parameters(
parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True
)
args = parser.parse_args()
if args.max_alternatives < 1:
parser.error("`--max-alternatives` must be greater than or equal to 1")
Expand All @@ -58,26 +61,26 @@ def streaming_transcription_worker(
max_alternatives=args.max_alternatives,
profanity_filter=args.profanity_filter,
enable_automatic_punctuation=args.automatic_punctuation,
verbatim_transcripts=not args.no_verbatim_transcripts,
verbatim_transcripts=not args.verbatim_transcripts,
enable_word_time_offsets=args.word_time_offsets,
model=args.model_name,
),
interim_results=True,
interim_results=args.interim_results,
)
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
riva.client.add_word_boosting_to_config(config, args.boosted_words_file, args.boosted_words_score)
for _ in range(args.num_iterations):
with riva.client.AudioChunkFileIterator(
args.input_file,
args.file_streaming_chunk,
args.chunk_duration_ms,
delay_callback=riva.client.sleep_audio_length if args.simulate_realtime else None,
) as audio_chunk_iterator:
riva.client.print_streaming(
responses=asr_service.streaming_response_generator(
audio_chunks=audio_chunk_iterator,
streaming_config=config,
audio_chunks=audio_chunk_iterator, streaming_config=config,
),
input_file=args.input_file,
output_file=output_file,
additional_info='time',
file_mode='a',
word_time_offsets=args.word_time_offsets,
)
except BaseException as e:
Expand All @@ -87,12 +90,12 @@ def streaming_transcription_worker(

def main() -> None:
args = parse_args()
print("Number of clients:", args.num_clients)
print("Number of clients:", args.num_parallel_requests)
print("Number of iterations:", args.num_iterations)
print("Input file:", args.input_file)
threads = []
exception_queue = queue.Queue()
for i in range(args.num_clients):
for i in range(args.num_parallel_requests):
t = Thread(target=streaming_transcription_worker, args=[args, f"output_{i:d}.txt", i, exception_queue])
t.start()
threads.append(t)
Expand All @@ -112,7 +115,8 @@ def main() -> None:
if all_dead:
break
time.sleep(0.05)
print(str(args.num_clients), "threads done, output written to output_<thread_id>.txt")
for i in range(args.num_parallel_requests):
print(f"Thread {i} done, output written to output_{i}.txt")


if __name__ == "__main__":
Expand Down
Loading