@@ -29,6 +29,8 @@ namespace nr = nvidia::riva;
2929namespace nr_tts = nvidia::riva::tts;
3030
3131DEFINE_string (text, " " , " Text to be synthesized" );
32+ DEFINE_string (
33+ text_file, " " , " Text file with list of sentences to be synthesized. Ignored if 'text' is set." );
3234DEFINE_string (audio_file, " output.wav" , " Output file" );
3335DEFINE_string (audio_encoding, " pcm" , " Audio encoding (pcm or opus)" );
3436DEFINE_string (riva_uri, " localhost:50051" , " Riva API server URI and port" );
@@ -37,6 +39,7 @@ DEFINE_string(ssl_client_key, "", "Path to SSL client certificates key");
3739DEFINE_string (ssl_client_cert, " " , " Path to SSL client certificates file" );
3840DEFINE_int32 (rate, 44100 , " Sample rate for the TTS output" );
3941DEFINE_bool (online, false , " Whether synthesis should be online or batch" );
42+ DEFINE_bool (streaming, false , " Whether synthesis should be streaming or batch" );
4043DEFINE_string (
4144 language, " en-US" ,
4245 " Language code as per [BCP-47](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) language tag." );
@@ -101,13 +104,15 @@ main(int argc, char** argv)
101104 std::stringstream str_usage;
102105 str_usage << " Usage: riva_tts_client " << std::endl;
103106 str_usage << " --text=<text> " << std::endl;
107+ str_usage << " --text_file=<filename> " << std::endl;
104108 str_usage << " --audio_file=<filename> " << std::endl;
105109 str_usage << " --audio_encoding=<pcm|opus> " << std::endl;
106110 str_usage << " --riva_uri=<server_name:port> " << std::endl;
107111 str_usage << " --rate=<sample_rate> " << std::endl;
108112 str_usage << " --language=<language-code> " << std::endl;
109113 str_usage << " --voice_name=<voice-name> " << std::endl;
110114 str_usage << " --online=<true|false> " << std::endl;
115+ str_usage << " --streaming=<true|false> " << std::endl;
111116 str_usage << " --ssl_root_cert=<filename>" << std::endl;
112117 str_usage << " --ssl_client_key=<filename>" << std::endl;
113118 str_usage << " --ssl_client_cert=<filename>" << std::endl;
@@ -134,10 +139,26 @@ main(int argc, char** argv)
134139 }
135140
136141 auto text = FLAGS_text;
137- if (text.length () == 0 ) {
138- LOG (ERROR) << " Input text cannot be empty." << std::endl;
142+ auto text_file = FLAGS_text_file;
143+ std::vector<std::string> text_lines;
144+ if (text.length () == 0 && text_file.length () == 0 ) {
145+ LOG (ERROR) << " Input text or text file cannot be empty." << std::endl;
146+ return -1 ;
147+ }
148+ if (text.length () > 0 && text_file.length () > 0 ) {
149+ LOG (ERROR) << " Only one of text or text file can be provided." << std::endl;
139150 return -1 ;
140151 }
152+ if (text_file.length () > 0 ) {
153+ std::ifstream infile (text_file);
154+ if (infile.is_open ()) {
155+ std::string line;
156+ while (std::getline (infile, line)) {
157+ text_lines.push_back (line);
158+ text += line + " " ;
159+ }
160+ }
161+ }
141162
142163 bool flag_set = gflags::GetCommandLineFlagInfoOrDie (" riva_uri" ).is_default ;
143164 const char * riva_uri = getenv (" RIVA_URI" );
@@ -152,7 +173,8 @@ main(int argc, char** argv)
152173 auto creds = riva::clients::CreateChannelCredentials (
153174 FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert,
154175 FLAGS_metadata);
155- grpc_channel = riva::clients::CreateChannelBlocking (FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
176+ grpc_channel = riva::clients::CreateChannelBlocking (
177+ FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
156178 }
157179 catch (const std::exception& e) {
158180 std::cerr << " Error creating GRPC channel: " << e.what () << std::endl;
@@ -251,7 +273,7 @@ main(int argc, char** argv)
251273 decoder.DeserializeOpus (std::vector<unsigned char >(ptr, ptr + audio.size ())));
252274 ::riva::utils::wav::Write (FLAGS_audio_file, rate, pcm.data(), pcm.size());
253275 }
254- } else { // online inference
276+ } else if (FLAGS_online && not FLAGS_streaming) { // batch inference
255277 if (not FLAGS_zero_shot_transcript.empty ()) {
256278 LOG (ERROR) << " Zero shot transcript is not supported for streaming inference." ;
257279 return -1 ;
@@ -261,8 +283,11 @@ main(int argc, char** argv)
261283 size_t audio_len = 0 ;
262284 nr_tts::SynthesizeSpeechResponse chunk;
263285 auto start = std::chrono::steady_clock::now ();
264- std::unique_ptr<grpc::ClientReader<nr_tts::SynthesizeSpeechResponse>> reader (
265- tts->SynthesizeOnline (&context, request));
286+ std::unique_ptr<
287+ grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
288+ reader (tts->SynthesizeOnline (&context));
289+ reader->Write (request);
290+ reader->WritesDone ();
266291 while (reader->Read (&chunk)) {
267292 // Copy chunk to local buffer
268293 if (audio_len == 0 ) {
@@ -295,6 +320,61 @@ main(int argc, char** argv)
295320 return -1 ;
296321 }
297322
323+ if (FLAGS_audio_encoding.empty () || FLAGS_audio_encoding == " pcm" ) {
324+ ::riva::utils::wav::Write (FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size());
325+ } else if (FLAGS_audio_encoding == " opus" ) {
326+ riva::utils::opus::Decoder decoder (rate, 1 );
327+ auto pcm = decoder.DecodePcm (decoder.DeserializeOpus (opus_buffer));
328+ ::riva::utils::wav::Write (FLAGS_audio_file, rate, pcm.data(), pcm.size());
329+ }
330+ } else if (FLAGS_online && FLAGS_streaming) { // streaming inference
331+
332+ std::vector<int16_t > pcm_buffer;
333+ std::vector<unsigned char > opus_buffer;
334+ size_t audio_len = 0 ;
335+ nr_tts::SynthesizeSpeechResponse chunk;
336+ auto start = std::chrono::steady_clock::now ();
337+ std::unique_ptr<
338+ grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
339+ reader (tts->SynthesizeOnline (&context));
340+ for (const auto & line : text_lines) {
341+ request.set_text (line);
342+ reader->Write (request);
343+ }
344+ reader->WritesDone ();
345+ while (reader->Read (&chunk)) {
346+ // Copy chunk to local buffer
347+ if (audio_len == 0 ) {
348+ auto t_first_audio = std::chrono::steady_clock::now ();
349+ std::chrono::duration<double > elapsed_first_audio = t_first_audio - start;
350+ LOG (INFO) << " Time to first chunk: " << elapsed_first_audio.count () << " s" << std::endl;
351+ }
352+ LOG (INFO) << " Got chunk: " << chunk.audio ().size () << " bytes" << std::endl;
353+ if (FLAGS_audio_encoding.empty () || FLAGS_audio_encoding == " pcm" ) {
354+ int16_t * audio_data = (int16_t *)chunk.audio ().data ();
355+ size_t len = chunk.audio ().length () / sizeof (int16_t );
356+ std::copy (audio_data, audio_data + len, std::back_inserter (pcm_buffer));
357+ audio_len += len;
358+ } else if (FLAGS_audio_encoding == " opus" ) {
359+ const unsigned char * opus_data = (unsigned char *)chunk.audio ().data ();
360+ size_t len = chunk.audio ().length ();
361+ std::copy (opus_data, opus_data + len, std::back_inserter (opus_buffer));
362+ audio_len += len;
363+ }
364+ }
365+ grpc::Status rpc_status = reader->Finish ();
366+ auto end = std::chrono::steady_clock::now ();
367+ std::chrono::duration<double > elapsed_total = end - start;
368+ LOG (INFO) << " Total streaming time: " << elapsed_total.count () << " s" << std::endl;
369+
370+ if (!rpc_status.ok ()) {
371+ // Report the RPC failure.
372+ LOG (ERROR) << rpc_status.error_message () << std::endl;
373+ LOG (ERROR) << " Input was: " << text_lines.size () << " lines." << std::endl;
374+ LOG (ERROR) << " Input was: \' " << text << " \' " << std::endl;
375+ return -1 ;
376+ }
377+
298378 if (FLAGS_audio_encoding.empty () || FLAGS_audio_encoding == " pcm" ) {
299379 ::riva::utils::wav::Write (FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size());
300380 } else if (FLAGS_audio_encoding == " opus" ) {
0 commit comments