Skip to content

Commit 6a1efd8

Browse files
online-transducer: reset the encoder toghter with 2 previous output symbols (non-blank) (#2129)
* online-transducer: reset the encoder toghter with 2 previous output symbols (non-blank) - added `reset_encoder` boolean member into the OnlineRecognizerConfig class - by default the encoder is not reset * pybind11, adding empty symbols for disabled modules (tts, diarization) * reset_encoder, add default value (false) [pybind11]
1 parent 921c437 commit 6a1efd8

File tree

6 files changed

+53
-10
lines changed

6 files changed

+53
-10
lines changed

sherpa-onnx/csrc/online-recognizer-transducer-impl.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -382,14 +382,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
382382
}
383383
}
384384

385-
// reset encoder states
386-
// s->SetStates(model_->GetEncoderInitStates());
387-
388385
auto r = decoder_->GetEmptyResult();
389386
auto last_result = s->GetResult();
390-
// if last result is not empty, then
391-
// truncate all last hyps and save as the context for next result
387+
392388
if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
389+
// if last result is not empty, then
390+
// truncate all last hyps and save as the 'ys' context for next result
391+
// (the encoder state buffers are kept)
393392
for (const auto &it : last_result.hyps) {
394393
auto h = it.second;
395394
r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
@@ -399,6 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
399398

400399
r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size,
401400
last_result.tokens.end());
401+
} else {
402+
if(config_.reset_encoder) {
403+
// reset encoder states, use blanks as 'ys' context
404+
s->SetStates(model_->GetEncoderInitStates());
405+
}
402406
}
403407

404408
// but reset all contextual biasing graph states to root

sherpa-onnx/csrc/online-recognizer.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
121121
"rule-fars", &rule_fars,
122122
"If not empty, it specifies fst archives for inverse text normalization. "
123123
"If there are multiple archives, they are separated by a comma.");
124+
125+
po->Register("reset-encoder", &reset_encoder,
126+
"True to reset encoder_state on an endpoint after empty segment."
127+
"Done in `Reset()` method, after an endpoint was detected.");
124128
}
125129

126130
bool OnlineRecognizerConfig::Validate() const {
@@ -198,7 +202,8 @@ std::string OnlineRecognizerConfig::ToString() const {
198202
os << "blank_penalty=" << blank_penalty << ", ";
199203
os << "temperature_scale=" << temperature_scale << ", ";
200204
os << "rule_fsts=\"" << rule_fsts << "\", ";
201-
os << "rule_fars=\"" << rule_fars << "\")";
205+
os << "rule_fars=\"" << rule_fars << "\", ";
206+
os << "reset_encoder=\"" << (reset_encoder ? "True" : "False") << "\")";
202207

203208
return os.str();
204209
}

sherpa-onnx/csrc/online-recognizer.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ struct OnlineRecognizerConfig {
7979
OnlineLMConfig lm_config;
8080
EndpointConfig endpoint_config;
8181
OnlineCtcFstDecoderConfig ctc_fst_decoder_config;
82+
8283
bool enable_endpoint = true;
8384

8485
std::string decoding_method = "greedy_search";
@@ -101,6 +102,11 @@ struct OnlineRecognizerConfig {
101102
// If there are multiple FST archives, they are applied from left to right.
102103
std::string rule_fars;
103104

105+
// True to reset encoder_state on an endpoint after empty segment.
106+
// Done in `Reset()` method, after an endpoint was detected,
107+
// currently only in `OnlineRecognizerTransducerImpl`.
108+
bool reset_encoder = false;
109+
104110
/// used only for modified_beam_search, if hotwords_buf is non-empty,
105111
/// the hotwords will be loaded from the buffered string instead of from the
106112
/// "hotwords_file"
@@ -116,7 +122,8 @@ struct OnlineRecognizerConfig {
116122
bool enable_endpoint, const std::string &decoding_method,
117123
int32_t max_active_paths, const std::string &hotwords_file,
118124
float hotwords_score, float blank_penalty, float temperature_scale,
119-
const std::string &rule_fsts, const std::string &rule_fars)
125+
const std::string &rule_fsts, const std::string &rule_fars,
126+
bool reset_encoder)
120127
: feat_config(feat_config),
121128
model_config(model_config),
122129
lm_config(lm_config),
@@ -130,7 +137,8 @@ struct OnlineRecognizerConfig {
130137
blank_penalty(blank_penalty),
131138
temperature_scale(temperature_scale),
132139
rule_fsts(rule_fsts),
133-
rule_fars(rule_fars) {}
140+
rule_fars(rule_fars),
141+
reset_encoder(reset_encoder) {}
134142

135143
void Register(ParseOptions *po);
136144
bool Validate() const;

sherpa-onnx/python/csrc/online-recognizer.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
5858
const OnlineLMConfig &, const EndpointConfig &,
5959
const OnlineCtcFstDecoderConfig &, bool,
6060
const std::string &, int32_t, const std::string &, float,
61-
float, float, const std::string &, const std::string &>(),
61+
float, float, const std::string &, const std::string &, bool>(),
6262
py::arg("feat_config"), py::arg("model_config"),
6363
py::arg("lm_config") = OnlineLMConfig(),
6464
py::arg("endpoint_config") = EndpointConfig(),
@@ -67,7 +67,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
6767
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
6868
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
6969
py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
70-
py::arg("rule_fars") = "")
70+
py::arg("rule_fars") = "", py::arg("reset_encoder") = false)
7171
.def_readwrite("feat_config", &PyClass::feat_config)
7272
.def_readwrite("model_config", &PyClass::model_config)
7373
.def_readwrite("lm_config", &PyClass::lm_config)
@@ -82,6 +82,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
8282
.def_readwrite("temperature_scale", &PyClass::temperature_scale)
8383
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
8484
.def_readwrite("rule_fars", &PyClass::rule_fars)
85+
.def_readwrite("reset_encoder", &PyClass::reset_encoder)
8586
.def("__str__", &PyClass::ToString);
8687
}
8788

sherpa-onnx/python/csrc/sherpa-onnx.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,15 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
7575

7676
#if SHERPA_ONNX_ENABLE_TTS == 1
7777
PybindOfflineTts(&m);
78+
#else
79+
/* Define "empty" TTS sybmbols */
80+
m.attr("OfflineTtsKokoroModelConfig") = py::none();
81+
m.attr("OfflineTtsMatchaModelConfig") = py::none();
82+
m.attr("OfflineTtsModelConfig") = py::none();
83+
m.attr("OfflineTtsVitsModelConfig") = py::none();
84+
m.attr("GeneratedAudio") = py::none();
85+
m.attr("OfflineTtsConfig") = py::none();
86+
m.attr("OfflineTts") = py::none();
7887
#endif
7988

8089
PybindSpeakerEmbeddingExtractor(&m);
@@ -85,6 +94,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
8594
PybindFastClustering(&m);
8695
PybindOfflineSpeakerDiarizationResult(&m);
8796
PybindOfflineSpeakerDiarization(&m);
97+
#else
98+
/* Define "empty" diarization sybmbols */
99+
m.attr("FastClusteringConfig") = py::none();
100+
m.attr("FastClustering") = py::none();
101+
m.attr("OfflineSpeakerDiarizationSegment") = py::none();
102+
m.attr("OfflineSpeakerDiarizationResult") = py::none();
103+
m.attr("OfflineSpeakerSegmentationPyannoteModelConfig") = py::none();
104+
m.attr("OfflineSpeakerSegmentationModelConfig") = py::none();
105+
m.attr("OfflineSpeakerDiarizationConfig") = py::none();
106+
m.attr("OfflineSpeakerDiarization") = py::none();
88107
#endif
89108

90109
PybindAlsa(&m);

sherpa-onnx/python/sherpa_onnx/online_recognizer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def from_transducer(
6868
lm_scale: float = 0.1,
6969
lm_shallow_fusion: bool = True,
7070
temperature_scale: float = 2.0,
71+
reset_encoder: bool = False,
7172
debug: bool = False,
7273
rule_fsts: str = "",
7374
rule_fars: str = "",
@@ -162,6 +163,10 @@ def from_transducer(
162163
Temperature scaling for output symbol confidence estiamation.
163164
It affects only confidence values, the decoding uses the original
164165
logits without temperature.
166+
reset_encoder:
167+
True to reset `encoder_state` on an endpoint after empty segment.
168+
Done in `Reset()` method, after an endpoint was detected,
169+
currently only in `OnlineRecognizerTransducerImpl`.
165170
model_type:
166171
Online transducer model type. Valid values are: conformer, lstm,
167172
zipformer, zipformer2. All other values lead to loading the model twice.
@@ -305,6 +310,7 @@ def from_transducer(
305310
temperature_scale=temperature_scale,
306311
rule_fsts=rule_fsts,
307312
rule_fars=rule_fars,
313+
reset_encoder=reset_encoder,
308314
)
309315

310316
self.recognizer = _Recognizer(recognizer_config)

0 commit comments

Comments
 (0)