Skip to content

asr: add direct gRPC client #72

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions scripts/asr/riva_streaming_asr_grpc_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Script to stream audio file to Riva and print FINAL transcripts with audio_processed info

import argparse
from pathlib import Path
import grpc
import riva.client
import riva.client.proto.riva_asr_pb2 as riva_asr_pb2
import riva.client.proto.riva_asr_pb2_grpc as riva_asr_pb2_grpc
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters


def read_responses(responses):
    """Print every transcript result received on a streaming response iterator.

    For each result that has at least one alternative, one line is printed in
    the form ``<LABEL>: <audio_processed> : <transcript>``, where ``<LABEL>`` is
    ``FINAL`` when ``result.is_final`` is true and ``PARTIAL`` otherwise, and
    ``audio_processed`` is formatted with two decimals. Only the first
    alternative of a result is printed.

    Args:
        responses: Iterable of response objects exposing ``results``; each
            result exposes ``alternatives``, ``is_final`` and
            ``audio_processed``.

    Returns:
        None. A ``grpc.RpcError`` raised while iterating is caught and its
        status code and details are printed instead of being propagated.
    """
    try:
        for response in responses:
            if not response.results:
                continue
            for result in response.results:
                # A result without alternatives carries no transcript to print.
                if not result.alternatives:
                    continue
                label = "FINAL" if result.is_final else "PARTIAL"
                print(f"{label}: {result.audio_processed:.2f} : {result.alternatives[0].transcript}")
    except grpc.RpcError as error:
        # The loop above is what pulls from the stream, so RPC failures
        # surface here rather than at call time.
        print(error.code(), error.details())


def generate_requests(args):
    """Yield the request stream for one streaming recognition call.

    The first request carries only the streaming configuration; every request
    after it carries one chunk of audio read from ``args.input_file``.

    Args:
        args: Parsed command-line namespace. ``input_file`` and
            ``chunk_duration_ms`` are required. ``interim_results``,
            ``max_alternatives`` and ``language_code`` are honoured when
            present and otherwise fall back to ``False``, ``1`` and
            ``"en-US"`` respectively.

    Yields:
        ``riva_asr_pb2.StreamingRecognizeRequest`` messages: the config
        request first, then one request per audio chunk.
    """
    print(f"File: {args.input_file}")

    # Take the recognition options from the CLI instead of hard-coding them, so
    # that `--interim-results` and `--max-alternatives` actually take effect.
    # The getattr fallbacks are the previously hard-coded values, which keeps
    # callers that pass a namespace without these attributes working.
    recognition_config = riva_asr_pb2.RecognitionConfig(
        language_code=getattr(args, "language_code", "en-US"),
        max_alternatives=getattr(args, "max_alternatives", 1),
        # NOTE(review): these two stay forced on; wiring them to the CLI would
        # change the script's default output if the corresponding flags
        # default to off — confirm against the argparse helper before changing.
        profanity_filter=True,
        enable_automatic_punctuation=True,
    )
    streaming_config = riva_asr_pb2.StreamingRecognitionConfig(
        config=recognition_config,
        interim_results=getattr(args, "interim_results", False),
    )

    # First send the config
    yield riva_asr_pb2.StreamingRecognizeRequest(streaming_config=streaming_config)

    # Followed by audio
    try:
        # NOTE(review): the second argument is passed a duration in
        # milliseconds; verify that AudioChunkFileIterator does not expect a
        # frame count here — TODO confirm.
        for audio_chunk in riva.client.AudioChunkFileIterator(args.input_file, args.chunk_duration_ms):
            yield riva_asr_pb2.StreamingRecognizeRequest(audio_content=audio_chunk)
    except Exception as e:
        # Best effort: report the problem and end the request stream instead of
        # raising from inside the generator.
        print(e)
        return


def parse_args() -> argparse.Namespace:
    """Build this script's command-line parser and parse the process arguments.

    Registers the script's own options (``--input-file``,
    ``--chunk-duration-ms``, ``--interim-results``), then lets the shared
    connection and ASR-config helpers add theirs.

    Returns:
        The parsed namespace. The parser exits with an error message when
        ``--max-alternatives`` is below 1.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Streaming transcription via Riva AI Services. Uses direct gRPC API",
    )

    # Options owned by this script.
    cli.add_argument(
        "--input-file",
        type=Path,
        required=True,
        help="Name of the WAV file with LINEAR_PCM encoding to transcribe.",
    )
    cli.add_argument(
        "--chunk-duration-ms",
        default=100,
        type=int,
        help="Chunk duration in milliseconds.",
    )
    cli.add_argument(
        "--interim-results",
        action="store_true",
        default=False,
        help="Print intermediate transcripts",
    )

    # Options contributed by the shared helpers; whatever they return is used
    # as the parser from here on.
    cli = add_connection_argparse_parameters(cli)
    cli = add_asr_config_argparse_parameters(
        cli,
        max_alternatives=True,
        profanity_filter=True,
        word_time_offsets=True,
    )

    parsed = cli.parse_args()
    if parsed.max_alternatives < 1:
        cli.error("`--max-alternatives` must be greater than or equal to 1")
    return parsed


def main() -> None:
    """Parse the CLI, open a channel to the server and print streamed transcripts."""
    args = parse_args()

    # Open channel. Forward the SSL certificate when the parser provides one
    # rather than always passing None; the getattr fallback keeps the previous
    # behaviour when the attribute is absent.
    auth = riva.client.Auth(
        getattr(args, "ssl_cert", None),
        use_ssl=args.use_ssl,
        uri=args.server,
        metadata_args=args.metadata,
    )

    # Create stub
    riva_stub = riva_asr_pb2_grpc.RivaSpeechRecognitionStub(auth.channel)

    # Attach the metadata collected by `auth` to the call itself; passing it
    # only to the Auth constructor does not put it on the RPC.
    responses = riva_stub.StreamingRecognize(
        generate_requests(args), metadata=auth.get_auth_metadata()
    )

    # Get response stream to read transcripts
    read_responses(responses)


# Run the client only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()