diff --git a/riva/clients/tts/riva_tts_client.cc b/riva/clients/tts/riva_tts_client.cc
index 05d9f94..eb6256b 100644
--- a/riva/clients/tts/riva_tts_client.cc
+++ b/riva/clients/tts/riva_tts_client.cc
@@ -29,6 +29,8 @@ namespace nr = nvidia::riva;
 namespace nr_tts = nvidia::riva::tts;
 
 DEFINE_string(text, "", "Text to be synthesized");
+DEFINE_string(
+    text_file, "", "Text file with list of sentences to be synthesized. Ignored if 'text' is set.");
 DEFINE_string(audio_file, "output.wav", "Output file");
 DEFINE_string(audio_encoding, "pcm", "Audio encoding (pcm or opus)");
 DEFINE_string(riva_uri, "localhost:50051", "Riva API server URI and port");
@@ -37,6 +39,7 @@ DEFINE_string(ssl_client_key, "", "Path to SSL client certificates key");
 DEFINE_string(ssl_client_cert, "", "Path to SSL client certificates file");
 DEFINE_int32(rate, 44100, "Sample rate for the TTS output");
 DEFINE_bool(online, false, "Whether synthesis should be online or batch");
+DEFINE_bool(streaming, false, "Whether synthesis should be streaming or batch");
 DEFINE_string(
     language, "en-US",
     "Language code as per [BCP-47](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) language tag.");
@@ -101,6 +104,7 @@ main(int argc, char** argv)
   std::stringstream str_usage;
   str_usage << "Usage: riva_tts_client " << std::endl;
   str_usage << "  --text=<text> " << std::endl;
+  str_usage << "  --text_file=<filename> " << std::endl;
   str_usage << "  --audio_file=<filename> " << std::endl;
   str_usage << "  --audio_encoding=<pcm|opus> " << std::endl;
   str_usage << "  --riva_uri=<server_name:port> " << std::endl;
@@ -108,6 +112,7 @@
   str_usage << "  --language=<language-code> " << std::endl;
   str_usage << "  --voice_name=<voice-name> " << std::endl;
   str_usage << "  --online=<true|false> " << std::endl;
+  str_usage << "  --streaming=<true|false> " << std::endl;
   str_usage << "  --ssl_root_cert=<filename>" << std::endl;
   str_usage << "  --ssl_client_key=<filename>" << std::endl;
   str_usage << "  --ssl_client_cert=<filename>" << std::endl;
@@ -134,10 +139,26 @@ main(int argc, char** argv)
   }
 
   auto text = FLAGS_text;
-  if (text.length() == 0) {
-    LOG(ERROR) << "Input text cannot be empty." << std::endl;
+  auto text_file = FLAGS_text_file;
+  std::vector<std::string> text_lines;
+  if (text.length() == 0 && text_file.length() == 0) {
+    LOG(ERROR) << "Input text or text file cannot be empty." << std::endl;
+    return -1;
+  }
+  if (text.length() > 0 && text_file.length() > 0) {
+    LOG(ERROR) << "Only one of text or text file can be provided." << std::endl;
     return -1;
   }
+  if (text_file.length() > 0) {
+    std::ifstream infile(text_file);
+    if (infile.is_open()) {
+      std::string line;
+      while (std::getline(infile, line)) {
+        text_lines.push_back(line);
+        text += line + " ";
+      }
+    }
+  }
 
   bool flag_set = gflags::GetCommandLineFlagInfoOrDie("riva_uri").is_default;
   const char* riva_uri = getenv("RIVA_URI");
@@ -152,7 +173,8 @@ main(int argc, char** argv)
     auto creds = riva::clients::CreateChannelCredentials(
         FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert,
         FLAGS_metadata);
-    grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
+    grpc_channel = riva::clients::CreateChannelBlocking(
+        FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
   }
   catch (const std::exception& e) {
     std::cerr << "Error creating GRPC channel: " << e.what() << std::endl;
@@ -251,7 +273,7 @@ main(int argc, char** argv)
           decoder.DeserializeOpus(std::vector<unsigned char>(ptr, ptr + audio.size())));
       ::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm.data(), pcm.size());
     }
-  } else {  // online inference
+  } else if (FLAGS_online && not FLAGS_streaming) {  // online inference
     if (not FLAGS_zero_shot_transcript.empty()) {
       LOG(ERROR) << "Zero shot transcript is not supported for streaming inference.";
       return -1;
     }
@@ -261,8 +283,11 @@ main(int argc, char** argv)
     size_t audio_len = 0;
     nr_tts::SynthesizeSpeechResponse chunk;
     auto start = std::chrono::steady_clock::now();
-    std::unique_ptr<grpc::ClientReader<nr_tts::SynthesizeSpeechResponse>> reader(
-        tts->SynthesizeOnline(&context, request));
+    std::unique_ptr<
+        grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
+        reader(tts->SynthesizeOnline(&context));
+    reader->Write(request);
+    reader->WritesDone();
     while (reader->Read(&chunk)) {
       // Copy chunk to local buffer
       if (audio_len == 0) {
@@ -295,6 +320,65 @@ main(int argc, char** argv)
       return -1;
     }
 
+    if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
+      ::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size());
+    } else if (FLAGS_audio_encoding == "opus") {
+      riva::utils::opus::Decoder decoder(rate, 1);
+      auto pcm = decoder.DecodePcm(decoder.DeserializeOpus(opus_buffer));
+      ::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm.data(), pcm.size());
+    }
+  } else if (FLAGS_online && FLAGS_streaming) {  // streaming inference
+
+    std::vector<int16_t> pcm_buffer;
+    std::vector<unsigned char> opus_buffer;
+    size_t audio_len = 0;
+    nr_tts::SynthesizeSpeechResponse chunk;
+    auto start = std::chrono::steady_clock::now();
+    std::unique_ptr<
+        grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
+        reader(tts->SynthesizeOnline(&context));
+    for (const auto& line : text_lines) {
+      if (line.find("|") != std::string::npos) {
+        request.set_text(line.substr(line.find("|") + 1, line.length()));
+      } else {
+        request.set_text(line);
+      }
+      reader->Write(request);
+    }
+    reader->WritesDone();
+    while (reader->Read(&chunk)) {
+      // Copy chunk to local buffer
+      if (audio_len == 0) {
+        auto t_first_audio = std::chrono::steady_clock::now();
+        std::chrono::duration<double> elapsed_first_audio = t_first_audio - start;
+        LOG(INFO) << "Time to first chunk: " << elapsed_first_audio.count() << " s" << std::endl;
+      }
+      LOG(INFO) << "Got chunk: " << chunk.audio().size() << " bytes" << std::endl;
+      if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
+        int16_t* audio_data = (int16_t*)chunk.audio().data();
+        size_t len = chunk.audio().length() / sizeof(int16_t);
+        std::copy(audio_data, audio_data + len, std::back_inserter(pcm_buffer));
+        audio_len += len;
+      } else if (FLAGS_audio_encoding == "opus") {
+        const unsigned char* opus_data = (unsigned char*)chunk.audio().data();
+        size_t len = chunk.audio().length();
+        std::copy(opus_data, opus_data + len, std::back_inserter(opus_buffer));
+        audio_len += len;
+      }
+    }
+    grpc::Status rpc_status = reader->Finish();
+    auto end = std::chrono::steady_clock::now();
+    std::chrono::duration<double> elapsed_total = end - start;
+    LOG(INFO) << "Total streaming time: " << elapsed_total.count() << " s" << std::endl;
+
+    if (!rpc_status.ok()) {
+      // Report the RPC failure.
+      LOG(ERROR) << rpc_status.error_message() << std::endl;
+      LOG(ERROR) << "Input was: " << text_lines.size() << " lines." << std::endl;
+      LOG(ERROR) << "Input was: \'" << text << "\'" << std::endl;
+      return -1;
+    }
+    if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
       ::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size());
     } else if (FLAGS_audio_encoding == "opus") {
       riva::utils::opus::Decoder decoder(rate, 1);
diff --git a/riva/clients/tts/riva_tts_perf_client.cc b/riva/clients/tts/riva_tts_perf_client.cc
index 0f15e6e..50b9b0e 100644
--- a/riva/clients/tts/riva_tts_perf_client.cc
+++ b/riva/clients/tts/riva_tts_perf_client.cc
@@ -38,6 +38,7 @@ DEFINE_string(audio_encoding, "pcm", "Audio encoding (pcm or opus)");
 DEFINE_string(riva_uri, "localhost:50051", "Riva API server URI and port");
 DEFINE_int32(rate, 44100, "Sample rate for the TTS output");
 DEFINE_bool(online, false, "Whether synthesis should be online or batch");
+DEFINE_bool(streaming, false, "Whether synthesis should be streaming input");
 DEFINE_bool(
     write_output_audio, false,
     "Whether to dump output audio or not. When true, throughput and latency are not reported.");
@@ -47,6 +48,7 @@ DEFINE_string(
 DEFINE_string(voice_name, "", "Desired voice name");
 DEFINE_int32(num_iterations, 1, "Number of times to loop over audio files");
 DEFINE_int32(num_parallel_requests, 1, "Number of parallel requests to keep in flight");
+DEFINE_int32(num_sentences, 1, "Number of sentences to send");
 DEFINE_int32(throttle_milliseconds, 0, "Number of milliseconds to sleep for between TTS requests");
 DEFINE_int32(offset_milliseconds, 0, "Number of milliseconds to offset each parallel TTS requests");
 DEFINE_string(ssl_root_cert, "", "Path to SSL root certificates file");
@@ -260,8 +262,11 @@ synthesizeOnline(
 
   nr_tts::SynthesizeSpeechResponse chunk;
   auto start = std::chrono::steady_clock::now();
-  std::unique_ptr<grpc::ClientReader<nr_tts::SynthesizeSpeechResponse>> reader(
-      tts->SynthesizeOnline(&context, request));
+  std::unique_ptr<
+      grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
+      reader(tts->SynthesizeOnline(&context));
+  reader->Write(request);
+  reader->WritesDone();
   DLOG(INFO) << "Sending request for input \"" << text << "\".";
 
   std::vector<int16_t> buffer;
@@ -344,9 +349,11 @@ main(int argc, char** argv)
   str_usage << "  --language=<language-code> " << std::endl;
   str_usage << "  --voice_name=<voice-name> " << std::endl;
   str_usage << "  --online=<true|false> " << std::endl;
+  str_usage << "  --streaming=<true|false> " << std::endl;
   str_usage << "  --audio_encoding=<pcm|opus> " << std::endl;
   str_usage << "  --num_parallel_requests=<num-requests> " << std::endl;
   str_usage << "  --num_iterations=<num-iterations> " << std::endl;
+  str_usage << "  --num_sentences=<num-sentences> " << std::endl;
   str_usage << "  --throttle_milliseconds=<milliseconds> " << std::endl;
   str_usage << "  --offset_milliseconds=<milliseconds> " << std::endl;
   str_usage << "  --ssl_root_cert=<filename>" << std::endl;
@@ -404,7 +411,7 @@ main(int argc, char** argv)
 
   // open text file, load sentences as a vector
   int count = 0;
-  for (int i = 0; i < FLAGS_num_iterations; i++) {
+  for (int i = 0; i < FLAGS_num_iterations * FLAGS_num_sentences; i++) {
     std::ifstream file(text_file);
     while (std::getline(file, sentence)) {
       if (sentence.find("|") != std::string::npos) {
@@ -458,38 +465,156 @@ main(int argc, char** argv)
         usleep(i * FLAGS_offset_milliseconds * 1000);
         auto start_time = std::chrono::steady_clock::now();
-        for (size_t s = 0; s < sentences[i].size(); s++) {
-          auto current_time = std::chrono::steady_clock::now();
-          double diff_time =
-              std::chrono::duration<double, std::milli>(current_time - start_time).count();
-          double wait_time = (s + 1) * FLAGS_throttle_milliseconds - diff_time;
-
-          // To nanoseconds
-          wait_time *= 1.e3;
-          wait_time = std::max(wait_time, 0.);
-
-          // Round to nearest integer
-          wait_time = wait_time + 0.5 - (wait_time < 0);
-          int64_t usecs = (int64_t)wait_time;
-          // Sleep
-          if (usecs > 0) {
-            usleep(usecs);
+        if (FLAGS_streaming) {
+          // Streaming mode: send all sentences in one stream
+          auto tts = CreateTTS(grpc_channel);
+
+          nr_tts::SynthesizeSpeechRequest request;
+          request.set_language_code(FLAGS_language);
+          request.set_sample_rate_hz(rate);
+          request.set_voice_name(FLAGS_voice_name);
+
+          auto ae = nr::AudioEncoding::ENCODING_UNSPECIFIED;
+          if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
+            ae = nr::LINEAR_PCM;
+          } else if (FLAGS_audio_encoding == "opus") {
+            ae = nr::OGGOPUS;
+          } else {
+            std::cerr << "Unsupported encoding: \'" << FLAGS_audio_encoding << "\'" << std::endl;
+            return;
+          }
+          request.set_encoding(ae);
+
+          if (not FLAGS_zero_shot_audio_prompt.empty()) {
+            auto zero_shot_data = request.mutable_zero_shot_data();
+            std::vector<std::shared_ptr<WaveData>> audio_prompt;
+            LoadWavData(audio_prompt, FLAGS_zero_shot_audio_prompt);
+            if (audio_prompt.size() != 1) {
+              LOG(ERROR) << "Unsupported number of audio prompts. Need exactly 1 audio prompt."
+                         << std::endl;
+              return;
+            }
+
+            if (audio_prompt[0]->encoding != nr::LINEAR_PCM &&
+                audio_prompt[0]->encoding != nr::OGGOPUS) {
+              LOG(ERROR) << "Unsupported encoding for zero shot prompt: \'"
+                         << audio_prompt[0]->encoding << "\'";
+              std::cerr << "Unsupported encoding for zero shot prompt: \'"
+                        << audio_prompt[0]->encoding << "\'" << std::endl;
+              return;
+            }
+            zero_shot_data->set_audio_prompt(
+                &audio_prompt[0]->data[0], audio_prompt[0]->data.size());
+            int32_t zero_shot_sample_rate = audio_prompt[0]->sample_rate;
+            zero_shot_data->set_encoding(audio_prompt[0]->encoding);
+            if (audio_prompt[0]->encoding == nr::OGGOPUS) {
+              zero_shot_sample_rate =
+                  riva::utils::opus::Decoder::AdjustRateIfUnsupported(zero_shot_sample_rate);
+            }
+            zero_shot_data->set_sample_rate_hz(zero_shot_sample_rate);
+            zero_shot_data->set_quality(FLAGS_zero_shot_quality);
           }
-          auto tts = CreateTTS(grpc_channel);
-          double time_to_first_chunk = 0.;
-          auto time_to_next_chunk = new std::vector<double>();
-          size_t num_samples = 0;
-          synthesizeOnline(
-              std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name,
-              &time_to_first_chunk, time_to_next_chunk, &num_samples,
-              std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt,
-              FLAGS_zero_shot_quality);
-          latencies_first_chunk[i]->push_back(time_to_first_chunk);
-          latencies_next_chunks[i]->insert(
-              latencies_next_chunks[i]->end(), time_to_next_chunk->begin(),
-              time_to_next_chunk->end());
-          lengths[i]->push_back(num_samples);
+          grpc::ClientContext context;
+          nr_tts::SynthesizeSpeechResponse chunk;
+          auto stream_start = std::chrono::steady_clock::now();
+
+          std::unique_ptr<grpc::ClientReaderWriter<
+              nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
+              reader(tts->SynthesizeOnline(&context));
+
+          // Write all sentences to the stream
+          for (size_t s = 0; s < sentences[i].size(); s++) {
+            request.set_text(sentences[i][s].second);
+            reader->Write(request);
+          }
+          reader->WritesDone();
+
+          std::vector<int16_t> buffer;
+          size_t audio_len = 0;
+          auto time_to_next_chunks = new std::vector<double>();
+          riva::utils::opus::Decoder opus_decoder(rate, 1);
+
+          // Read all responses
+          while (reader->Read(&chunk)) {
+            size_t len = 0U;
+            if (ae == nr::OGGOPUS) {
+              const unsigned char* opus_data = (unsigned char*)chunk.audio().data();
+              len = chunk.audio().length();
+              auto pcm = opus_decoder.DecodePcm(opus_decoder.DeserializeOpus(
+                  std::vector<unsigned char>(opus_data, opus_data + len)));
+              len = pcm.size();
+              std::copy(pcm.cbegin(), pcm.cend(), std::back_inserter(buffer));
+            } else {
+              int16_t* audio_data;
+              audio_data = (int16_t*)chunk.audio().data();
+              len = chunk.audio().length() / sizeof(int16_t);
+              std::copy(audio_data, audio_data + len, std::back_inserter(buffer));
+            }
+
+            if (audio_len == 0) {
+              auto t_next_audio = std::chrono::steady_clock::now();
+              std::chrono::duration<double> elapsed_first_audio = t_next_audio - stream_start;
+              latencies_first_chunk[i]->push_back(elapsed_first_audio.count());
+              stream_start = t_next_audio;
+            } else {
+              auto t_next_audio = std::chrono::steady_clock::now();
+              std::chrono::duration<double> elapsed_next_audio = t_next_audio - stream_start;
+              time_to_next_chunks->push_back(elapsed_next_audio.count());
+              stream_start = t_next_audio;
+            }
+            audio_len += len;
+          }
+
+          grpc::Status rpc_status = reader->Finish();
+
+          if (!rpc_status.ok()) {
+            std::cerr << rpc_status.error_message() << std::endl;
+            std::cerr << "Streaming input failed for worker " << i << std::endl;
+          } else {
+            lengths[i]->push_back(audio_len);
+            latencies_next_chunks[i]->insert(
+                latencies_next_chunks[i]->end(), time_to_next_chunks->begin(),
+                time_to_next_chunks->end());
+            if (FLAGS_write_output_audio) {
+              ::riva::utils::wav::Write(
+                  "worker_" + std::to_string(i) + ".wav", rate, buffer.data(), buffer.size());
+            }
+          }
+        } else {
+          // Non-streaming mode: send one sentence per stream
+          for (size_t s = 0; s < sentences[i].size(); s++) {
+            auto current_time = std::chrono::steady_clock::now();
+            double diff_time =
+                std::chrono::duration<double, std::milli>(current_time - start_time).count();
+            double wait_time = (s + 1) * FLAGS_throttle_milliseconds - diff_time;
+
+            // To microseconds
+            wait_time *= 1.e3;
+            wait_time = std::max(wait_time, 0.);
+
+            // Round to nearest integer
+            wait_time = wait_time + 0.5 - (wait_time < 0);
+            int64_t usecs = (int64_t)wait_time;
+            // Sleep
+            if (usecs > 0) {
+              usleep(usecs);
+            }
+
+            auto tts = CreateTTS(grpc_channel);
+            double time_to_first_chunk = 0.;
+            auto time_to_next_chunk = new std::vector<double>();
+            size_t num_samples = 0;
+            synthesizeOnline(
+                std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name,
+                &time_to_first_chunk, time_to_next_chunk, &num_samples,
+                std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt,
+                FLAGS_zero_shot_quality);
+            latencies_first_chunk[i]->push_back(time_to_first_chunk);
+            latencies_next_chunks[i]->insert(
+                latencies_next_chunks[i]->end(), time_to_next_chunk->begin(),
+                time_to_next_chunk->end());
+            lengths[i]->push_back(num_samples);
+          }
         }
       }));
     }
@@ -588,4 +713,3 @@ main(int argc, char** argv)
   }
   return STATUS;
 }
-
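With both clients patched, the streaming input path can be exercised end to end, for example (file and output names hypothetical):

    riva_tts_client --online=true --streaming=true --text_file=sentences.txt --audio_file=output.wav
    riva_tts_perf_client --streaming=true --num_sentences=8 --text_file=sentences.txt

Two behavioral notes follow from the code above: in riva_tts_client, --streaming only takes effect together with --online=true, since the branch chain tests batch first, then online, then online-plus-streaming; and the streaming branch only sends sentences collected from --text_file (lines may carry an id| prefix, which is stripped before sending), so as written, combining --streaming with a bare --text opens a stream and writes no requests at all. In the perf client, each worker writes its stream's decoded audio to worker_<i>.wav when --write_output_audio is set.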