Skip to content

Commit c6bfb27

Browse files
authored
Fix word confidence return (#15249)
1 parent 5487c7e commit c6bfb27

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

nemo/collections/asr/parts/utils/asr_confidence_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from omegaconf import DictConfig, OmegaConf
2323

2424
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
25-
from nemo.utils import logging
2625

2726

2827
class ConfidenceMethodConstants:
@@ -447,7 +446,7 @@ def _aggregate_token_confidence_subwords_sentencepiece(
447446
prev_underline = False
448447
for i, token_id in enumerate(token_ids):
449448
token = self.decode_ids_to_tokens([int(token_id)])[0]
450-
token_text = self.decode_tokens_to_str([int(token_id)])
449+
token_text = self.decode_ids_to_str([int(token_id)])
451450
# treat `<unk>` as a separate word regardless of the next token
452451
# to match the result of `tokenizer.ids_to_text`
453452
if (token != token_text or prev_unk) and i > j:

nemo/collections/asr/parts/utils/chunking_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def merge_parallel_chunks(hypotheses, encoded_len, model, timestamps, subsamplin
7070
timestamp=([] if not timestamps else {'word': [], 'segment': []}),
7171
)
7272
merged_hypotheses = join_y_sequence(merged_hypotheses, hypotheses)
73+
merged_hypotheses = join_confidence_values(merged_hypotheses, hypotheses)
7374
merged_hypotheses.text = final_text
7475
# Merge timestamps and add word and segment level timestamps
7576
if timestamps:
@@ -99,6 +100,44 @@ def join_y_sequence(merged_hypothesis, hypotheses):
99100
return merged_hypothesis
100101

101102

def join_confidence_values(merged_hypothesis, hypotheses):
    """
    Concatenate confidence values from multiple hypotheses into a single sequence.

    Merges the ``frame_confidence``, ``token_confidence`` and ``word_confidence``
    attributes. For each attribute, hypotheses whose value is None are skipped.
    Tensor values are concatenated with ``torch.cat``; list values are flattened
    into one list. If no hypothesis carries a value for an attribute, that
    attribute on ``merged_hypothesis`` is left untouched.

    Args:
        merged_hypothesis: Target hypothesis to update with concatenated confidence
        hypotheses: List of hypotheses containing confidence values

    Returns:
        Hypothesis: Updated merged_hypothesis with concatenated confidence values
    """
    # The three confidence attributes are merged with identical logic, so
    # handle them in one loop instead of three copy-pasted branches.
    for attr in ('frame_confidence', 'token_confidence', 'word_confidence'):
        values = [getattr(h, attr) for h in hypotheses if getattr(h, attr) is not None]
        if not values:
            # Nothing to merge for this attribute.
            continue
        if isinstance(values[0], torch.Tensor):
            setattr(merged_hypothesis, attr, torch.cat(values))
        elif isinstance(values[0], list):
            # Flatten the per-hypothesis lists into a single flat list.
            setattr(merged_hypothesis, attr, [c for conf_list in values for c in conf_list])
        # Any other container type is left unset, matching the original behavior.

    return merged_hypothesis
102141
def join_timestamp_and_add_word_and_segment_level_timestamps(
103142
merged_hypotheses, hypotheses, chunk_offsets, subsampling_factor, window_stride, decoding, merged_tokens=None
104143
):
@@ -307,6 +346,9 @@ def merge_hypotheses_of_same_audio(hypotheses_list, timestamps, subsampling_fact
307346

308347
merged_hypothesis.y_sequence = torch.cat([h.y_sequence for h in hypotheses_list])
309348

349+
# Merge confidence values from all hypotheses
350+
merged_hypothesis = join_confidence_values(merged_hypothesis, hypotheses_list)
351+
310352
# Create final text by joining text from all hypotheses
311353
text_parts = []
312354
for hyp in hypotheses_list:

0 commit comments

Comments
 (0)