
feat: s2s client #34


Open: wants to merge 4 commits into main
125 changes: 125 additions & 0 deletions scripts/nmt/s2s_mic.py
@@ -0,0 +1,125 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

import argparse
import wave
from typing import Iterable, Optional

import riva.client
import riva.client.audio_io
import riva.client.proto.riva_nmt_pb2 as riva_nmt
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters

def parse_args() -> argparse.Namespace:
    default_device_info = riva.client.audio_io.get_default_input_device_info()
    default_device_index = None if default_device_info is None else default_device_info['index']
    parser = argparse.ArgumentParser(
        description="Streaming speech-to-speech translation from microphone via Riva AI Services",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--input-device", type=int, default=default_device_index, help="An input audio device to use.")
    parser.add_argument("--list-input-devices", action="store_true", help="List input audio device indices.")
    parser.add_argument("--list-output-devices", action="store_true", help="List output audio device indices.")
    parser.add_argument("--output-device", type=int, help="Output device to use.")
    parser.add_argument("--target-language-code", default="en-US", help="Language code of the output language.")
    parser.add_argument("--tts-voice-name", default="English-US.Female-1", help="Voice name of the TTS model.")
    parser.add_argument(
        "--play-audio",
Contributor:
If --play-audio is not set, the script doesn't give any output. We should probably add an --output parameter, as in tts/talk.py, so that the script can still produce output when run against a server.

action="store_true",
help="Play input audio simultaneously with transcribing and translating it. If `--output-device` is not provided, "
"then the default output audio device will be used.",
)
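(For reference, a minimal sketch of the flag the reviewer suggests, modeled on tts/talk.py; the flag name, type, and help text are assumptions, not part of this diff:)

    parser.add_argument(
        "--output",
        type=str,
        help="Output .wav file to write synthesized audio to.",  # hypothetical flag mirroring tts/talk.py
    )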

    parser = add_asr_config_argparse_parameters(parser, profanity_filter=True)
Contributor:
You'll probably need to set max_alternatives=False and word_time_offsets=False because these parameters are pointless for this script. Do you think we also need to add a speaker_diarization=False flag?
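(A sketch of the call the reviewer proposes; it assumes add_asr_config_argparse_parameters accepts these boolean switches, as the comment implies:)

    parser = add_asr_config_argparse_parameters(
        parser, profanity_filter=True, max_alternatives=False, word_time_offsets=False
    )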

    parser = add_connection_argparse_parameters(parser)
    parser.add_argument(
        "--sample-rate-hz",
        type=int,
        help="Number of frames per second in audio streamed from the microphone.",
        default=16000,
    )
    parser.add_argument(
        "--file-streaming-chunk",
        type=int,
        default=1600,
        help="Maximum number of frames in an audio chunk sent to the server.",
    )
    args = parser.parse_args()
    return args

def play_responses(
    responses: Iterable[riva_nmt.StreamingTranslateSpeechToSpeechResponse],
    sound_stream: Optional[riva.client.audio_io.SoundCallBack],
) -> None:
    count = 0
    for response in responses:
        if sound_stream is not None:
            sound_stream(response.speech.audio)
        # Also dump each synthesized chunk to its own WAV file.
        fname = f"response{count}.wav"
        with wave.open(fname, 'wb') as out_f:
            out_f.setnchannels(1)
            out_f.setsampwidth(2)
            out_f.setframerate(44100)
            out_f.writeframesraw(response.speech.audio)
        count += 1


def main() -> None:
    args = parse_args()
    sound_stream = None
    sampwidth = 2
    nchannels = 1
Contributor, on lines +69 to +70:
sampwidth and nchannels are set in two places: here and in the play_responses() function. Could you make them global variables?
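(A sketch of the suggested refactor, with illustrative names; these module-level constants would replace the literals used here and in play_responses():)

NCHANNELS = 1  # audio channels shared by main() and play_responses()
SAMPWIDTH = 2  # bytes per sample
FRAMERATE_HZ = 44100  # TTS output sample rate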

    if args.list_input_devices:
        riva.client.audio_io.list_input_devices()
        return
    if args.list_output_devices:
        riva.client.audio_io.list_output_devices()
        return

    if args.output_device is not None or args.play_audio:
        print("playing audio")
        sound_stream = riva.client.audio_io.SoundCallBack(
            args.output_device, nchannels=nchannels, sampwidth=sampwidth, framerate=44100
Contributor:
Maybe we should make framerate a parameter of the script, like --sample-rate-hz in tts/talk.py?
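(A sketch of the suggested flag, by analogy with tts/talk.py; the name --output-sample-rate-hz is an assumption:)

    parser.add_argument(
        "--output-sample-rate-hz",
        type=int,
        default=44100,
        help="Sample rate of the synthesized output audio.",  # hypothetical flag
    )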

        )
        print(sound_stream)
Contributor:
Why do we need this print?

    auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
    nmt_service = riva.client.NeuralMachineTranslationClient(auth)
    s2s_config = riva.client.StreamingTranslateSpeechToSpeechConfig(
Contributor:
Do we need a tts_config as in the proto? If so, we could add an add_tts_config_argparse_parameters() function to argparse_utils.py and refactor tts/talk.py to use it.
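(A sketch of the helper the reviewer has in mind; the function and its flag set are hypothetical and would live in argparse_utils.py:)

def add_tts_config_argparse_parameters(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Hypothetical helper mirroring add_asr_config_argparse_parameters().
    parser.add_argument("--voice", help="A TTS voice name.")
    parser.add_argument("--tts-sample-rate-hz", type=int, default=44100, help="Sample rate of synthesized audio.")
    return parser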

        asr_config=riva.client.StreamingRecognitionConfig(
            config=riva.client.RecognitionConfig(
                encoding=riva.client.AudioEncoding.LINEAR_PCM,
                language_code=args.language_code,
                max_alternatives=1,
                profanity_filter=args.profanity_filter,
                enable_automatic_punctuation=args.automatic_punctuation,
                verbatim_transcripts=not args.no_verbatim_transcripts,
                sample_rate_hertz=args.sample_rate_hz,
                audio_channel_count=1,
            ),
            interim_results=True,
        ),
        translation_config=riva.client.TranslationConfig(
            target_language_code=args.target_language_code,
Contributor:
This should include source_language_code and probably also model_name, as in the config.
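(A sketch of the fields the reviewer points to; reusing args.language_code for the source and adding a --model-name argument are both assumptions:)

        translation_config=riva.client.TranslationConfig(
            source_language_code=args.language_code,  # assumption: reuse the ASR language flag
            target_language_code=args.target_language_code,
            model_name=args.model_name,  # assumption: a new --model-name argument
        ),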

        ),
        tts_config=riva.client.SynthesizeSpeechConfig(
            encoding=1,
            sample_rate_hz=44100,
            voice_name=args.tts_voice_name,
            language_code=args.target_language_code,
        ),
    )

    with riva.client.audio_io.MicrophoneStream(
        args.sample_rate_hz,
        args.file_streaming_chunk,
        device=args.input_device,
    ) as audio_chunk_iterator:
        play_responses(responses=nmt_service.streaming_s2s_response_generator(
            audio_chunks=audio_chunk_iterator,
            streaming_config=s2s_config), sound_stream=sound_stream)
Contributor, on lines +114 to +116:
Suggested change:

        play_responses(responses=nmt_service.streaming_s2s_response_generator(
            audio_chunks=audio_chunk_iterator,
            streaming_config=s2s_config), sound_stream=sound_stream)

        play_responses(
            responses=nmt_service.streaming_s2s_response_generator(
                audio_chunks=audio_chunk_iterator,
                streaming_config=s2s_config,
            ),
            sound_stream=sound_stream,
        )

if __name__ == '__main__':
    main()
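For reference, a typical invocation (assuming a Riva server reachable at the default localhost:50051, which add_connection_argparse_parameters provides via --server, and a Spanish translation model deployed on it) might look like:

python scripts/nmt/s2s_mic.py \
    --server localhost:50051 \
    --target-language-code es-US \
    --play-audio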