Skip to content

Commit e2ebba3

Browse files
junkinPeganovAnton
andauthored
feat(nmt): add nmt client (#18)
* add nmt files Update __init__.py * bump patch version * update submodule * proto location change * chore: Roll version numbers * fix(nmt): Match existing style * Update riva/client/nmt.py Co-authored-by: PeganovAnton <[email protected]> * Update scripts/nmt/nmt.py Co-authored-by: PeganovAnton <[email protected]> * Update scripts/nmt/nmt.py Co-authored-by: PeganovAnton <[email protected]> Co-authored-by: PeganovAnton <[email protected]>
1 parent ad71e0e commit e2ebba3

File tree

6 files changed

+188
-7
lines changed

6 files changed

+188
-7
lines changed

riva/client/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@
3535
from riva.client.proto.riva_asr_pb2 import RecognitionConfig, StreamingRecognitionConfig
3636
from riva.client.proto.riva_audio_pb2 import AudioEncoding
3737
from riva.client.proto.riva_nlp_pb2 import AnalyzeIntentOptions
38-
from riva.client.tts import SpeechSynthesisService
38+
from riva.client.tts import SpeechSynthesisService
39+
from riva.client.nmt import NeuralMachineTranslationClient

riva/client/nmt.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: MIT
3+
4+
from typing import Generator, Optional, Union, List
5+
6+
from grpc._channel import _MultiThreadedRendezvous
7+
8+
import riva.client.proto.riva_nmt_pb2 as riva_nmt
9+
import riva.client.proto.riva_nmt_pb2_grpc as riva_nmt_srv
10+
from riva.client import Auth
11+
12+
13+
class NeuralMachineTranslationClient:
14+
"""
15+
A class for translating text to text. Provides :meth:`translate` which returns translated text
16+
"""
17+
def __init__(self, auth: Auth) -> None:
18+
"""
19+
Initializes an instance of the class.
20+
21+
Args:
22+
auth (:obj:`Auth`): an instance of :class:`riva.client.auth.Auth` which is used for authentication metadata
23+
generation.
24+
"""
25+
self.auth = auth
26+
self.stub = riva_nmt_srv.RivaTranslationStub(self.auth.channel)
27+
28+
def translate(
29+
self,
30+
texts: List[str],
31+
model: str,
32+
source_language: str,
33+
target_language: str,
34+
future: bool = False,
35+
) -> Union[riva_nmt.TranslateTextResponse, _MultiThreadedRendezvous]:
36+
"""
37+
Translate input list of input text :param:`text` using model :param:`model` from :param:`source_language` into :param:`target_language`
38+
39+
Args:
40+
text (:obj:`list[str]`): input text.
41+
future (:obj:`bool`, defaults to :obj:`False`): whether to return an async result instead of usual
42+
response. You can get a response by calling ``result()`` method of the future object.
43+
44+
Returns:
45+
:obj:`Union[riva.client.proto.riva_nmt_pb2.TranslateTextResponse, grpc._channel._MultiThreadedRendezvous]`:
46+
a response with output. You may find :class:`riva.client.proto.riva_nmt_pb2.TranslateTextResponse` fields
47+
description `here
48+
<https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
49+
"""
50+
req = riva_nmt.TranslateTextRequest(
51+
texts=texts,
52+
model=model,
53+
source_language=source_language,
54+
target_language=target_language
55+
)
56+
57+
func = self.stub.TranslateText.future if future else self.stub.TranslateText
58+
return func(req, metadata=self.auth.get_auth_metadata())
59+
60+
def get_config(
61+
self,
62+
model: str,
63+
future: bool = False,
64+
) -> Union[riva_nmt.AvailableLanguageResponse, _MultiThreadedRendezvous]:
65+
req = riva_nmt.AvailableLanguageRequest(model=model)
66+
func = self.stub.ListSupportedLanguagePairs.future if future else self.stub.ListSupportedLanguagePairs
67+
return func(req, metadata=self.auth.get_auth_metadata())

riva/client/package_info.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
MAJOR = 0
55
MINOR = 0
6-
PATCH = 5
6+
PATCH = 6
77
PRE_RELEASE = 'rc0'
88

99
# Use the following formatting: (major, minor, patch, pre-release)
@@ -20,7 +20,7 @@
2020
__download_url__ = 'hhttps://github.com/nvidia-riva/python-clients/releases'
2121
__description__ = "Python implementation of the Riva Client API"
2222
__license__ = 'MIT'
23-
__keywords__ = 'deep learning, machine learning, gpu, NLP, ASR, TTS, nvidia, speech, language, Riva, client'
24-
__riva_version__ = "2.3.0"
25-
__riva_release__ = "22.06"
26-
__riva_models_version__ = "2.3.0"
23+
__keywords__ = 'deep learning, machine learning, gpu, NLP, ASR, TTS, NMT, nvidia, speech, language, Riva, client'
24+
__riva_version__ = "2.7.0"
25+
__riva_release__ = "22.10"
26+
__riva_models_version__ = "2.7.0"

scripts/nmt/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#q SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: MIT

scripts/nmt/nmt.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
#!/usr/bin/env python
28+
29+
import argparse
30+
import os
31+
import sys
32+
33+
import grpc
34+
import riva.client.proto.riva_nmt_pb2 as riva_nmt
35+
import riva.client.proto.riva_nmt_pb2_grpc as riva_nmt_srv
36+
37+
import riva.client
38+
from riva.client.argparse_utils import add_connection_argparse_parameters
39+
40+
41+
def parse_args() -> argparse.Namespace:
42+
parser = argparse.ArgumentParser(
43+
description="Neural machine translation by Riva AI Services",
44+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
45+
)
46+
inputs = parser.add_mutually_exclusive_group()
47+
inputs.add_argument(
48+
"--text", default="mir Das ist mir Wurs, bien ich ein berliner", type=str, help="Text to translate"
49+
)
50+
inputs.add_argument("--text-file", type=str, help="Path to file for translation")
51+
parser.add_argument("--model-name", default="riva-nmt", type=str, help="model to use to translate")
52+
parser.add_argument(
53+
"--src-language", type=str, help="Source language (according to BCP-47 standard)"
54+
)
55+
parser.add_argument(
56+
"--tgt-language", type=str, help="Target language (according to BCP-47 standard)"
57+
)
58+
parser.add_argument("--text-file", type=str, help="Path to file for translation")
59+
parser.add_argument("--batch-size", type=int, default=8, help="Batch size to use for file translation")
60+
parser.add_argument("--list-models", default=False, action='store_true', help="List available models")
61+
parser = add_connection_argparse_parameters(parser)
62+
63+
return parser.parse_args()
64+
65+
66+
def main() -> None:
67+
def request(inputs,args):
68+
try:
69+
response = nmt_client.translate(inputs, args.model_name, args.src_language, args.tgt_language)
70+
for translation in response.translations:
71+
print(translation.text)
72+
except grpc.RpcError as e:
73+
if e.code() == grpc.StatusCode.INVALID_ARGUMENT:
74+
result = {'msg': 'invalid arg error'}
75+
elif e.code() == grpc.StatusCode.ALREADY_EXISTS:
76+
result = {'msg': 'already exists error'}
77+
elif e.code() == grpc.StatusCode.UNAVAILABLE:
78+
result = {'msg': 'server unavailable check network'}
79+
print(f"{result['msg']} : {e.details()}")
80+
81+
args = parse_args()
82+
83+
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
84+
nmt_client = riva.client.NeuralMachineTranslationClient(auth)
85+
86+
if args.list_models:
87+
88+
response = nmt_client.get_config(args.model_name)
89+
print(response)
90+
return
91+
92+
if args.text_file != None and os.path.exists(args.text_file):
93+
with open(args.text_file, "r") as f:
94+
batch = []
95+
for line in f:
96+
line = line.strip()
97+
if line != "":
98+
batch.append(line)
99+
if len(batch) == args.batch_size:
100+
request(batch, args)
101+
batch = []
102+
if len(batch) > 0:
103+
request(batch, args)
104+
return
105+
106+
if args.text != "":
107+
request([args.text], args)
108+
109+
110+
if __name__ == '__main__':
111+
main()

0 commit comments

Comments
 (0)