Skip to content

Commit 708e598

Browse files
atomer-nvidiamohnishparmarydharavathmanishaj-nv
authored
Merge release/2.18.0 to main (#111)
* fix: Check for None type custom dict (#106) fix: Check for None type custom dict * Add DNT and custom translation support in NMT client (#108) * changes made for custom translation, dnt phrases * formatted * minor fixes done * changes made * minor fix done * changes done * changes made * changes made for dnt and custom translation * .gitmodules changed --------- Co-authored-by: Manisha Johnson <[email protected]> * fix: undeclared variable (#109) * Updating git SHA to point to TOT main * Updating git SHA to point to TOT main --------- Co-authored-by: mohnishparmar <[email protected]> Co-authored-by: ydharavath <[email protected]> Co-authored-by: Manisha Johnson <[email protected]>
1 parent 6d24b56 commit 708e598

File tree

4 files changed

+58
-6
lines changed

4 files changed

+58
-6
lines changed

common

riva/client/nmt.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ def streaming_s2t_request_generator(
2222
for chunk in audio_chunks:
2323
yield riva_nmt.StreamingTranslateSpeechToTextRequest(audio_content=chunk)
2424

25+
def add_dnt_phrases_dict(req, dnt_phrases_dict):
26+
dnt_phrases = None
27+
if dnt_phrases_dict is not None:
28+
dnt_phrases = [f"{key}##{value}" for key, value in dnt_phrases_dict.items()]
29+
if dnt_phrases:
30+
result_dnt_phrases = ",".join(dnt_phrases)
31+
req.dnt_phrases.append(result_dnt_phrases)
32+
2533
class NeuralMachineTranslationClient:
2634
"""
2735
A class for translating text to text. Provides :meth:`translate` which returns translated text
@@ -137,6 +145,7 @@ def translate(
137145
source_language: str,
138146
target_language: str,
139147
future: bool = False,
148+
dnt_phrases_dict: Optional[dict] = None,
140149
) -> Union[riva_nmt.TranslateTextResponse, _MultiThreadedRendezvous]:
141150
"""
142151
Translate input list of input text :param:`text` using model :param:`model` from :param:`source_language` into :param:`target_language`
@@ -158,7 +167,7 @@ def translate(
158167
source_language=source_language,
159168
target_language=target_language
160169
)
161-
170+
add_dnt_phrases_dict(req, dnt_phrases_dict)
162171
func = self.stub.TranslateText.future if future else self.stub.TranslateText
163172
return func(req, metadata=self.auth.get_auth_metadata())
164173

riva/client/tts.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
import wave
1313

1414
def add_custom_dictionary_to_config(req, custom_dictionary):
15-
result_list = [f"{key} {value}" for key, value in custom_dictionary.items()]
16-
result_string = ','.join(result_list)
17-
req.custom_dictionary = result_string
15+
result_list = None
16+
if custom_dictionary is not None:
17+
result_list = [f"{key} {value}" for key, value in custom_dictionary.items()]
18+
if result_list:
19+
result_string = ','.join(result_list)
20+
req.custom_dictionary = result_string
1821

1922
class SpeechSynthesisService:
2023
"""

scripts/nmt/nmt.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,35 @@
3838
from riva.client.argparse_utils import add_connection_argparse_parameters
3939

4040

41+
def read_dnt_phrases_file(file_path):
42+
dnt_phrases_dict = {}
43+
if file_path:
44+
try:
45+
with open(file_path, "r") as infile:
46+
for line in infile:
47+
# Trim leading and trailing whitespaces
48+
line = line.strip()
49+
50+
if line:
51+
pos = line.find("##")
52+
if pos != -1:
53+
# Line contains "##"
54+
key = line[:pos].strip()
55+
value = line[pos + 2 :].strip()
56+
else:
57+
# Line doesn't contain "##"
58+
key = line.strip()
59+
value = ""
60+
61+
# Add the key-value pair to the dictionary
62+
if key:
63+
dnt_phrases_dict[key] = value
64+
65+
except IOError:
66+
raise RuntimeError(f"Could not open file {file_path}")
67+
68+
return dnt_phrases_dict
69+
4170
def parse_args() -> argparse.Namespace:
4271
parser = argparse.ArgumentParser(
4372
description="Neural machine translation by Riva AI Services",
@@ -48,6 +77,7 @@ def parse_args() -> argparse.Namespace:
4877
"--text", default="mir Das ist mir Wurs, bien ich ein berliner", type=str, help="Text to translate"
4978
)
5079
inputs.add_argument("--text-file", type=str, help="Path to file for translation")
80+
parser.add_argument("--dnt-phrases-file", type=str, help="Path to file which contains dnt phrases and custom translations")
5181
parser.add_argument("--model-name", default="", type=str, help="model to use to translate")
5282
parser.add_argument(
5383
"--source-language-code", type=str, default="en-US", help="Source language code (according to BCP-47 standard)"
@@ -65,7 +95,17 @@ def parse_args() -> argparse.Namespace:
6595
def main() -> None:
6696
def request(inputs,args):
6797
try:
68-
response = nmt_client.translate(inputs, args.model_name, args.source_language_code, args.target_language_code)
98+
dnt_phrases_input = {}
99+
if args.dnt_phrases_file != None:
100+
dnt_phrases_input = read_dnt_phrases_file(args.dnt_phrases_file)
101+
response = nmt_client.translate(
102+
texts=inputs,
103+
model=args.model_name,
104+
source_language=args.source_language_code,
105+
target_language=args.target_language_code,
106+
future=False,
107+
dnt_phrases_dict=dnt_phrases_input,
108+
)
69109
for translation in response.translations:
70110
print(translation.text)
71111
except grpc.RpcError as e:

0 commit comments

Comments
 (0)