Skip to content

Commit 53e37f7

Browse files
committed
chore: Update TTS client
1 parent a543e92 commit 53e37f7

File tree

1 file changed

+86
-6
lines changed

1 file changed

+86
-6
lines changed

riva/clients/tts/riva_tts_client.cc

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ namespace nr = nvidia::riva;
2929
namespace nr_tts = nvidia::riva::tts;
3030

3131
DEFINE_string(text, "", "Text to be synthesized");
32+
DEFINE_string(
33+
text_file, "", "Text file with list of sentences to be synthesized. Ignored if 'text' is set.");
3234
DEFINE_string(audio_file, "output.wav", "Output file");
3335
DEFINE_string(audio_encoding, "pcm", "Audio encoding (pcm or opus)");
3436
DEFINE_string(riva_uri, "localhost:50051", "Riva API server URI and port");
@@ -37,6 +39,7 @@ DEFINE_string(ssl_client_key, "", "Path to SSL client certificates key");
3739
DEFINE_string(ssl_client_cert, "", "Path to SSL client certificates file");
3840
DEFINE_int32(rate, 44100, "Sample rate for the TTS output");
3941
DEFINE_bool(online, false, "Whether synthesis should be online or batch");
42+
DEFINE_bool(streaming, false, "Whether synthesis should be streaming or batch");
4043
DEFINE_string(
4144
language, "en-US",
4245
"Language code as per [BCP-47](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) language tag.");
@@ -101,13 +104,15 @@ main(int argc, char** argv)
101104
std::stringstream str_usage;
102105
str_usage << "Usage: riva_tts_client " << std::endl;
103106
str_usage << " --text=<text> " << std::endl;
107+
str_usage << " --text_file=<filename> " << std::endl;
104108
str_usage << " --audio_file=<filename> " << std::endl;
105109
str_usage << " --audio_encoding=<pcm|opus> " << std::endl;
106110
str_usage << " --riva_uri=<server_name:port> " << std::endl;
107111
str_usage << " --rate=<sample_rate> " << std::endl;
108112
str_usage << " --language=<language-code> " << std::endl;
109113
str_usage << " --voice_name=<voice-name> " << std::endl;
110114
str_usage << " --online=<true|false> " << std::endl;
115+
str_usage << " --streaming=<true|false> " << std::endl;
111116
str_usage << " --ssl_root_cert=<filename>" << std::endl;
112117
str_usage << " --ssl_client_key=<filename>" << std::endl;
113118
str_usage << " --ssl_client_cert=<filename>" << std::endl;
@@ -134,10 +139,26 @@ main(int argc, char** argv)
134139
}
135140

136141
auto text = FLAGS_text;
137-
if (text.length() == 0) {
138-
LOG(ERROR) << "Input text cannot be empty." << std::endl;
142+
auto text_file = FLAGS_text_file;
143+
std::vector<std::string> text_lines;
144+
if (text.length() == 0 && text_file.length() == 0) {
145+
LOG(ERROR) << "Input text or text file cannot be empty." << std::endl;
146+
return -1;
147+
}
148+
if (text.length() > 0 && text_file.length() > 0) {
149+
LOG(ERROR) << "Only one of text or text file can be provided." << std::endl;
139150
return -1;
140151
}
152+
if (text_file.length() > 0) {
153+
std::ifstream infile(text_file);
154+
if (infile.is_open()) {
155+
std::string line;
156+
while (std::getline(infile, line)) {
157+
text_lines.push_back(line);
158+
text += line + " ";
159+
}
160+
}
161+
}
141162

142163
bool flag_set = gflags::GetCommandLineFlagInfoOrDie("riva_uri").is_default;
143164
const char* riva_uri = getenv("RIVA_URI");
@@ -152,7 +173,8 @@ main(int argc, char** argv)
152173
auto creds = riva::clients::CreateChannelCredentials(
153174
FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert,
154175
FLAGS_metadata);
155-
grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
176+
grpc_channel = riva::clients::CreateChannelBlocking(
177+
FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
156178
}
157179
catch (const std::exception& e) {
158180
std::cerr << "Error creating GRPC channel: " << e.what() << std::endl;
@@ -251,7 +273,7 @@ main(int argc, char** argv)
251273
decoder.DeserializeOpus(std::vector<unsigned char>(ptr, ptr + audio.size())));
252274
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm.data(), pcm.size());
253275
}
254-
} else { // online inference
276+
} else if (FLAGS_online && not FLAGS_streaming) { // batch inference
255277
if (not FLAGS_zero_shot_transcript.empty()) {
256278
LOG(ERROR) << "Zero shot transcript is not supported for streaming inference.";
257279
return -1;
@@ -261,8 +283,11 @@ main(int argc, char** argv)
261283
size_t audio_len = 0;
262284
nr_tts::SynthesizeSpeechResponse chunk;
263285
auto start = std::chrono::steady_clock::now();
264-
std::unique_ptr<grpc::ClientReader<nr_tts::SynthesizeSpeechResponse>> reader(
265-
tts->SynthesizeOnline(&context, request));
286+
std::unique_ptr<
287+
grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
288+
reader(tts->SynthesizeOnline(&context));
289+
reader->Write(request);
290+
reader->WritesDone();
266291
while (reader->Read(&chunk)) {
267292
// Copy chunk to local buffer
268293
if (audio_len == 0) {
@@ -295,6 +320,61 @@ main(int argc, char** argv)
295320
return -1;
296321
}
297322

323+
if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
324+
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size());
325+
} else if (FLAGS_audio_encoding == "opus") {
326+
riva::utils::opus::Decoder decoder(rate, 1);
327+
auto pcm = decoder.DecodePcm(decoder.DeserializeOpus(opus_buffer));
328+
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm.data(), pcm.size());
329+
}
330+
} else if (FLAGS_online && FLAGS_streaming) { // streaming inference
331+
332+
std::vector<int16_t> pcm_buffer;
333+
std::vector<unsigned char> opus_buffer;
334+
size_t audio_len = 0;
335+
nr_tts::SynthesizeSpeechResponse chunk;
336+
auto start = std::chrono::steady_clock::now();
337+
std::unique_ptr<
338+
grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
339+
reader(tts->SynthesizeOnline(&context));
340+
for (const auto& line : text_lines) {
341+
request.set_text(line);
342+
reader->Write(request);
343+
}
344+
reader->WritesDone();
345+
while (reader->Read(&chunk)) {
346+
// Copy chunk to local buffer
347+
if (audio_len == 0) {
348+
auto t_first_audio = std::chrono::steady_clock::now();
349+
std::chrono::duration<double> elapsed_first_audio = t_first_audio - start;
350+
LOG(INFO) << "Time to first chunk: " << elapsed_first_audio.count() << " s" << std::endl;
351+
}
352+
LOG(INFO) << "Got chunk: " << chunk.audio().size() << " bytes" << std::endl;
353+
if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
354+
int16_t* audio_data = (int16_t*)chunk.audio().data();
355+
size_t len = chunk.audio().length() / sizeof(int16_t);
356+
std::copy(audio_data, audio_data + len, std::back_inserter(pcm_buffer));
357+
audio_len += len;
358+
} else if (FLAGS_audio_encoding == "opus") {
359+
const unsigned char* opus_data = (unsigned char*)chunk.audio().data();
360+
size_t len = chunk.audio().length();
361+
std::copy(opus_data, opus_data + len, std::back_inserter(opus_buffer));
362+
audio_len += len;
363+
}
364+
}
365+
grpc::Status rpc_status = reader->Finish();
366+
auto end = std::chrono::steady_clock::now();
367+
std::chrono::duration<double> elapsed_total = end - start;
368+
LOG(INFO) << "Total streaming time: " << elapsed_total.count() << " s" << std::endl;
369+
370+
if (!rpc_status.ok()) {
371+
// Report the RPC failure.
372+
LOG(ERROR) << rpc_status.error_message() << std::endl;
373+
LOG(ERROR) << "Input was: " << text_lines.size() << " lines." << std::endl;
374+
LOG(ERROR) << "Input was: \'" << text << "\'" << std::endl;
375+
return -1;
376+
}
377+
298378
if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
299379
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size());
300380
} else if (FLAGS_audio_encoding == "opus") {

0 commit comments

Comments
 (0)