diff --git a/riva/clients/nmt/riva_nmt_t2t_client.cc b/riva/clients/nmt/riva_nmt_t2t_client.cc index 1825405..3721eb6 100644 --- a/riva/clients/nmt/riva_nmt_t2t_client.cc +++ b/riva/clients/nmt/riva_nmt_t2t_client.cc @@ -90,13 +90,31 @@ translateBatch( } } +int countWords(const std::string& text) { + + int wordCount = 0; + bool inside_word = false; + + for (char c : text) { + if (std::isspace(c)) { + inside_word = false; + } else if (!std::ispunct(c)) { + if (!inside_word) { + wordCount++; + inside_word = true; + } + } + } + + return wordCount; +} int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); FLAGS_logtostderr = 1; - + std::stringstream str_usage; str_usage << "Usage: riva_nmt_t2t_client" << std::endl; str_usage << " --text_file= " << std::endl; @@ -195,6 +213,7 @@ main(int argc, char** argv) std::cout << response.translations(0).text() << std::endl; return 0; } + int total_words = 0; if (FLAGS_text_file != "") { // pull strings into vectors per parallel request @@ -220,6 +239,7 @@ main(int argc, char** argv) batch.clear(); } if (!str.empty()) { + total_words += countWords(str); batch.push_back(make_pair(count, str)); count++; } @@ -255,6 +275,7 @@ main(int argc, char** argv) workers.push_back(std::thread([&, i]() { std::unique_ptr nmt2( nr_nmt::RivaTranslation::NewStub(grpc_channel)); + translateBatch( std::move(nmt2), request_queue, FLAGS_target_language_code, FLAGS_source_language_code, FLAGS_model_name, mtx, latencies, lmtx, responses.at(i)); @@ -270,13 +291,14 @@ main(int argc, char** argv) } } } + auto end = std::chrono::steady_clock::now(); std::chrono::duration total = end - start; LOG(INFO) << FLAGS_model_name << "-" << FLAGS_batch_size << "-" << FLAGS_source_language_code << "-" << FLAGS_target_language_code << ",count:" << count << ",total time: " << total.count() << ",requests/second: " << FLAGS_num_iterations * request_count / total.count() - << ",translations/second: " << FLAGS_num_iterations * count / total.count(); + << ",translations/second: " << total_words/total.count(); std::sort(latencies.begin(), latencies.end()); auto size = latencies.size(); diff --git a/riva/clients/nmt/streaming_s2t_client.cc b/riva/clients/nmt/streaming_s2t_client.cc index b62f132..4501738 100644 --- a/riva/clients/nmt/streaming_s2t_client.cc +++ b/riva/clients/nmt/streaming_s2t_client.cc @@ -256,6 +256,8 @@ StreamingS2TClient::PostProcessResults(std::shared_ptr call, bool VLOG(1) << "Latency:" << lat << std::endl; latencies_.push_back(lat); } + std::ofstream result_file(nmt_text_file_, std::ios::app); + call->PrintResult(audio_device, result_file); } void @@ -267,7 +269,6 @@ StreamingS2TClient::ReceiveResponses(std::shared_ptr call, bool a gotoxy(0, 5); } - std::ofstream result_file(nmt_text_file_); while (call->streamer->Read(&call->response)) { // Returns false when no more to read. call->recv_times.push_back(std::chrono::steady_clock::now()); for (int r = 0; r < call->response.results_size(); ++r) { @@ -279,13 +280,9 @@ StreamingS2TClient::ReceiveResponses(std::shared_ptr call, bool a gotoxy(0, 5); } VLOG(1) << "Result: " << result.DebugString(); - std::cout << "translated text: \"" << result.alternatives(0).transcript() << "\"" - << std::endl; - result_file << result.alternatives(0).transcript() << std::endl; } } - result_file.close(); - + PostProcessResults(call, audio_device); grpc::Status status = call->streamer->Finish(); if (!status.ok()) { // Report the RPC failure.