diff --git a/runtime/onnxruntime/include/audio.h b/runtime/onnxruntime/include/audio.h index 3011050ed..b14d3b85d 100644 --- a/runtime/onnxruntime/include/audio.h +++ b/runtime/onnxruntime/include/audio.h @@ -100,6 +100,8 @@ class DLLAPI Audio { int offset = 0; int speech_start=-1, speech_end=0; int speech_offline_start=-1; + int64_t start = 0; + int64_t end = 0; int seg_sample = MODEL_SAMPLE_RATE/1000; bool LoadPcmwavOnline(const char* buf, int n_file_len, int32_t* sampling_rate); diff --git a/runtime/onnxruntime/include/funasrruntime.h b/runtime/onnxruntime/include/funasrruntime.h index 685c0241f..1a3cff607 100644 --- a/runtime/onnxruntime/include/funasrruntime.h +++ b/runtime/onnxruntime/include/funasrruntime.h @@ -70,6 +70,8 @@ _FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index); _FUNASRAPI const char* FunASRGetStamp(FUNASR_RESULT result); _FUNASRAPI const char* FunASRGetStampSents(FUNASR_RESULT result); _FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index); +_FUNASRAPI const int64_t FunASRGetTpassStart(FUNASR_RESULT result); +_FUNASRAPI const int64_t FunASRGetTpassEnd(FUNASR_RESULT result); _FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result); _FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result); _FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle); diff --git a/runtime/onnxruntime/src/audio.cpp b/runtime/onnxruntime/src/audio.cpp index 22a9ecd29..4ceb56a3a 100644 --- a/runtime/onnxruntime/src/audio.cpp +++ b/runtime/onnxruntime/src/audio.cpp @@ -1289,6 +1289,10 @@ void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYP } } }else{ + + int sample_rate = 16000; // sample_rate 是音频的采样率 这里固定为16000 Hz + float segment_duration = (static_cast(seg_sample) / sample_rate) * 1000; // 每个分段的持续时间(毫秒) + for(auto vad_segment: vad_segments){ int speech_start_i=-1, speech_end_i=-1; if(vad_segment[0] != -1){ @@ -1325,6 +1329,12 @@ void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYP frame = nullptr; } + //设置开始时间和结束时间 + float start_time = speech_start_i * segment_duration; // 开始时间(毫秒) + float end_time = speech_end_i * segment_duration; // 结束时间(毫秒) + // 转换为 int64_t 类型并赋值给类的成员变量 + this->start = static_cast(start_time); + this->end = static_cast(end_time); speech_start = -1; speech_offline_start = -1; // [70, -1] @@ -1350,6 +1360,8 @@ void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYP } } + float start_time = speech_start_i * segment_duration; // 仅有开始时间 + this->start = static_cast(start_time); }else if(speech_end_i != -1){ // [-1,100] if(speech_start == -1 || speech_offline_start == -1){ LOG(ERROR) <<"Vad start is null while vad end is available. Set vad start 0" ; @@ -1399,6 +1411,8 @@ void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYP frame = nullptr; } } + float end_time = speech_end_i * segment_duration; // 仅有结束时间 + this->end = static_cast(end_time); speech_start = -1; speech_offline_start = -1; } diff --git a/runtime/onnxruntime/src/commonfunc.h b/runtime/onnxruntime/src/commonfunc.h index 6fd553fe0..81fa2422e 100644 --- a/runtime/onnxruntime/src/commonfunc.h +++ b/runtime/onnxruntime/src/commonfunc.h @@ -12,6 +12,8 @@ typedef struct std::string stamp_sents; std::string tpass_msg; float snippet_time; + int64_t start = 0; + int64_t end = 0; }FUNASR_RECOG_RESULT; typedef struct diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp index 628641268..1eb8230bd 100644 --- a/runtime/onnxruntime/src/funasrruntime.cpp +++ b/runtime/onnxruntime/src/funasrruntime.cpp @@ -523,6 +523,8 @@ p_result->snippet_time = audio->GetTimeLen(); audio->Split(vad_online_handle, chunk_len, input_finished, mode); + p_result->start = audio->start; + p_result->end = audio->end; funasr::AudioFrame* frame = nullptr; while(audio->FetchChunck(frame) > 0){ @@ -695,6 +697,23 @@ return p_result->tpass_msg.c_str(); } + _FUNASRAPI const int64_t FunASRGetTpassStart(FUNASR_RESULT result) + { + funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result; + if(!p_result) + return 0; + + return p_result->start; + } + _FUNASRAPI const int64_t FunASRGetTpassEnd(FUNASR_RESULT result) + { + funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result; + if(!p_result) + return 0; + + return p_result->end; + } + _FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index) { funasr::FUNASR_PUNC_RESULT * p_result = (funasr::FUNASR_PUNC_RESULT*)result; diff --git a/runtime/websocket/bin/websocket-server-2pass.cpp b/runtime/websocket/bin/websocket-server-2pass.cpp index ff23e9d41..4834f3423 100644 --- a/runtime/websocket/bin/websocket-server-2pass.cpp +++ b/runtime/websocket/bin/websocket-server-2pass.cpp @@ -15,11 +15,19 @@ #include #include #include +#include +#include extern std::unordered_map hws_map_; extern int fst_inc_wts_; extern float global_beam_, lattice_beam_, am_scale_; +int64_t getCurrentTimeMillis() { + auto now = std::chrono::system_clock::now(); + auto millis = std::chrono::duration_cast(now.time_since_epoch()).count(); + return millis; +} + context_ptr WebSocketServer::on_tls_init(tls_mode mode, websocketpp::connection_hdl hdl, std::string& s_certfile, @@ -57,7 +65,13 @@ context_ptr WebSocketServer::on_tls_init(tls_mode mode, return ctx; } -nlohmann::json handle_result(FUNASR_RESULT result) { +nlohmann::json handle_result(FUNASR_RESULT result, websocketpp::connection_hdl& hdl, std::map,std::owner_less>& data_map) { + std::shared_ptr data_msg = nullptr; + auto it = data_map.find(hdl); + if (it != data_map.end()) { + data_msg = it->second; + } + websocketpp::lib::error_code ec; nlohmann::json jsonresult; jsonresult["text"] = ""; @@ -67,12 +81,34 @@ nlohmann::json handle_result(FUNASR_RESULT result) { LOG(INFO) << "online_res :" << tmp_online_msg; jsonresult["text"] = tmp_online_msg; jsonresult["mode"] = "2pass-online"; + jsonresult["slice_type"] = 1; + jsonresult["index"] = data_msg->index; + + // 如果是第一句话的第一个实时结果或新的句子开始 + if (!data_msg->is_sentence_started) { + data_msg->start_time = FunASRGetTpassStart(result); // 记录句子的开始时间 + jsonresult["slice_type"] = 0; //0:一段话开始识别; 1:一段话识别中; 2:一段话识别结束 + data_msg->is_sentence_started = true; + } } + + data_msg->end_time = FunASRGetTpassEnd(result); // 记录句子的结束时间 + jsonresult["timestamp"] = data_msg->timestamp; + std::string tmp_tpass_msg = FunASRGetTpassResult(result, 0); if (tmp_tpass_msg != "") { LOG(INFO) << "offline results : " << tmp_tpass_msg; jsonresult["text"] = tmp_tpass_msg; jsonresult["mode"] = "2pass-offline"; + + // 句子结束,记录结束时间 + jsonresult["start_time"] = data_msg->start_time; + jsonresult["end_time"] = data_msg->end_time; + jsonresult["slice_type"] = 2; + jsonresult["index"] = data_msg->index; + + data_msg->index++; //句子序号 + data_msg->is_sentence_started = false; // 重置句子状态 } std::string tmp_stamp_msg = FunASRGetStamp(result); @@ -98,6 +134,7 @@ nlohmann::json handle_result(FUNASR_RESULT result) { } // feed buffer to asr engine for decoder void WebSocketServer::do_decoder( + std::map,std::owner_less>& data_map, std::vector& buffer, websocketpp::connection_hdl& hdl, nlohmann::json& msg, @@ -158,10 +195,11 @@ void WebSocketServer::do_decoder( } if (Result) { websocketpp::lib::error_code ec; - nlohmann::json jsonresult = handle_result(Result); + nlohmann::json jsonresult = handle_result(Result, hdl, data_map); jsonresult["wav_name"] = wav_name; jsonresult["is_final"] = false; if (jsonresult["text"] != "") { + LOG(INFO) << "jsonresult: " << jsonresult.dump(4); if (is_ssl) { wss_server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text, ec); @@ -200,9 +238,10 @@ void WebSocketServer::do_decoder( } if (Result) { websocketpp::lib::error_code ec; - nlohmann::json jsonresult = handle_result(Result); + nlohmann::json jsonresult = handle_result(Result, hdl, data_map); jsonresult["wav_name"] = wav_name; jsonresult["is_final"] = true; + LOG(INFO) << "jsonresult: " << jsonresult.dump(4); if (is_ssl) { wss_server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text, ec); @@ -254,7 +293,7 @@ void WebSocketServer::on_open(websocketpp::connection_hdl hdl) { data_msg->msg["audio_fs"] = 16000; // default is 16k data_msg->msg["access_num"] = 0; // the number of access for this object, when it is 0, we can free it saftly data_msg->msg["is_eof"]=false; // if this connection is closed - data_msg->msg["svs_lang"]="auto"; + data_msg->msg["svs_lang"]="zh"; data_msg->msg["svs_itn"]=true; FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, global_beam_, lattice_beam_, am_scale_); @@ -263,6 +302,10 @@ void WebSocketServer::on_open(websocketpp::connection_hdl hdl) { std::make_shared>>(2); data_msg->strand_ = std::make_shared(io_decoder_); + data_msg->is_sentence_started = false; + + data_msg->timestamp = getCurrentTimeMillis(); + data_map.emplace(hdl, data_msg); }catch (std::exception const& e) { std::cerr << "Error: " << e.what() << std::endl; @@ -501,6 +544,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl, std::vector> hotwords_embedding_(*(msg_data->hotwords_embedding)); msg_data->strand_->post( std::bind(&WebSocketServer::do_decoder, this, + data_map, std::move(*(sample_data_p.get())), std::move(hdl), std::ref(msg_data->msg), std::ref(*(punc_cache_p.get())), std::move(hotwords_embedding_), @@ -550,6 +594,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl, std::vector> hotwords_embedding_(*(msg_data->hotwords_embedding)); msg_data->strand_->post( std::bind(&WebSocketServer::do_decoder, this, + data_map, std::move(subvector), std::move(hdl), std::ref(msg_data->msg), std::ref(*(punc_cache_p.get())), diff --git a/runtime/websocket/bin/websocket-server-2pass.h b/runtime/websocket/bin/websocket-server-2pass.h index e61a93b2d..3ba63089c 100644 --- a/runtime/websocket/bin/websocket-server-2pass.h +++ b/runtime/websocket/bin/websocket-server-2pass.h @@ -61,7 +61,13 @@ typedef struct { std::string online_res = ""; std::string tpass_res = ""; std::shared_ptr strand_; // for data execute in order - FUNASR_DEC_HANDLE decoder_handle=nullptr; + FUNASR_DEC_HANDLE decoder_handle=nullptr; + + bool is_sentence_started = false; + int64_t start_time = 0; + int64_t end_time = 0; + int64_t index = 0; + int64_t timestamp = 0; } FUNASR_MESSAGE; // See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about @@ -114,7 +120,9 @@ class WebSocketServer { server_->clear_access_channels(websocketpp::log::alevel::all); } } - void do_decoder(std::vector& buffer, websocketpp::connection_hdl& hdl, + void do_decoder(std::map,std::owner_less>& data_map, + std::vector& buffer, + websocketpp::connection_hdl& hdl, nlohmann::json& msg, std::vector>& punc_cache, std::vector> &hotwords_embedding,