Skip to content

Commit 18a6ed5

Browse files
authored
Preserve more context after endpointing in transducer (#2061)
1 parent da4aad1 commit 18a6ed5

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

sherpa-onnx/csrc/online-recognizer-transducer-impl.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -388,16 +388,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
388388
auto r = decoder_->GetEmptyResult();
389389
auto last_result = s->GetResult();
390390
// if last result is not empty, then
391-
// preserve last tokens as the context for next result
391+
// truncate all last hyps and save as the context for next result
392392
if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
393-
std::vector<int64_t> context(last_result.tokens.end() - context_size,
394-
last_result.tokens.end());
393+
for (const auto &it : last_result.hyps) {
394+
auto h = it.second;
395+
r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
396+
h.ys.end()),
397+
h.log_prob});
398+
}
395399

396-
Hypotheses context_hyp({{context, 0}});
397-
r.hyps = std::move(context_hyp);
398-
r.tokens = std::move(context);
400+
r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size,
401+
last_result.tokens.end());
399402
}
400403

404+
// but reset all contextual biasing graph states to root
401405
if (config_.decoding_method == "modified_beam_search" &&
402406
nullptr != s->GetContextGraph()) {
403407
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {

0 commit comments

Comments
 (0)