Commit 454fabc

Canary2 with NFA (#14121)

Authored by monica-sekoyan and nune-tadevosyan
* adding nfa to canary
* remove comments
* Apply isort and black reformatting
* modify external model loading
* Apply isort and black reformatting
* fix audio padding
* reseting
* Apply isort and black reformatting
* handle non-possible alignment
* Apply isort and black reformatting
* add offset refinement
* Apply isort and black reformatting
* Revert "reseting" (this reverts commit 6d74ad0)
* Revert "Apply isort and black reformatting" (this reverts commit 1d8c363)
* handle merge case for timestamps
* add timestamp_type
* Apply isort and black reformatting
* add timestamps support chunked inference
* refactor ctc timestamps to use utils
* correct restore_token_cased with unk_token
* use timestamps utils in rnnt_decoding
* change external timestamps asr model loading
* add forced aligned method tests
* modify nfa to match new setup and utils
* Apply isort and black reformatting
* remove unused imports
* merge conflicts
* remove unused errors
* Apply isort and black reformatting
* remove unused import
* addressing comments, linting and flake8
* Apply isort and black reformatting
* handle decode_ids_to_str change
* Apply isort and black reformatting
* correct usage of decode_tokens_to_str
* update nfa docs
* Apply isort and black reformatting
* revert jupyter settings
* correct description
* make private
* rewrite restore_timestamps_asr_model
* Apply isort and black reformatting
* fix word offset logic
* Apply isort and black reformatting

Signed-off-by: Monica Sekoyan <msekoyan@nvidia.com>
Signed-off-by: monica-sekoyan <monica-sekoyan@users.noreply.github.com>
Signed-off-by: monica-sekoyan <msekoyan@nvidia.com>
Signed-off-by: Nune <ntadevosyan@nvidia.com>
Signed-off-by: nune-tadevosyan <nune-tadevosyan@users.noreply.github.com>
Signed-off-by: monica-sekoyan <msekoyan@vidia.com>
Co-authored-by: monica-sekoyan <monica-sekoyan@users.noreply.github.com>
Co-authored-by: Nune <ntadevosyan@nvidia.com>
Co-authored-by: nune-tadevosyan <nune-tadevosyan@users.noreply.github.com>
Co-authored-by: monica-sekoyan <msekoyan@vidia.com>
1 parent e295dbc commit 454fabc

29 files changed: +2743 -1819 lines changed

docs/source/tools/nemo_forced_aligner.rst

Lines changed: 2 additions & 2 deletions

@@ -64,9 +64,9 @@ Optional parameters:

 * ``use_local_attention``: boolean flag specifying whether to try to use local attention for the ASR Model (will only work if the ASR Model is a Conformer model). If local attention is used, we will set the local attention context size to [64,64].

-* ``additional_segment_grouping_separator``: an optional string used to separate the text into smaller segments. If this is not specified, then the whole text will be treated as a single segment. (Default: ``None``. Cannot be empty string or space (" "), as NFA will automatically produce word-level timestamps for substrings separated by spaces).
+* ``additional_segment_grouping_separator``: a list of strings used to separate the text into smaller segments. If set to ``None``, the whole text will be treated as a single segment. (Default: ``['.', '?', '!', '...']``. A separator cannot be the empty string or a space (" "), as NFA automatically produces word-level timestamps for substrings separated by spaces.)

-.. note:: the ``additional_segment_grouping_separator`` will be removed from the reference text and all the output files, ie it is treated as a marker which is not part of the reference text. The separator will essentially be treated as a space, and any additional spaces around it will be amalgamated into one, i.e. if ``additional_segment_grouping_separator="|"``, the following texts will be treated equivalently: ``"abc|def"``, ``"abc |def"``, ``"abc| def"``, ``"abc | def"``.
+.. note:: Starting in NeMo 2.5.0, separators are preserved in the segment text after splitting. If ``additional_segment_grouping_separator=['.', '?', '!', '...']`` (the default), the text ``"Hi, have you updated your NeMo? Yes. Sure!"`` yields the segments ``["Hi, have you updated your NeMo?", "Yes.", "Sure!"]``.

 * ``remove_blank_tokens_from_ctm``: a boolean denoting whether to remove <blank> tokens from token-level output CTMs. (Default: False).
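The splitting behavior described in the note can be sketched in plain Python. This is an illustrative re-implementation, not NFA's actual code; the helper name `split_segments` and the regex approach are assumptions.

```python
import re

def split_segments(text, separators=('...', '.', '?', '!')):
    # Longest separators first so "..." is not consumed as three "." splits.
    pattern = '|'.join(re.escape(s) for s in sorted(separators, key=len, reverse=True))
    # A capturing group makes re.split keep each matched separator,
    # so it can be re-attached to the text that precedes it.
    parts = re.split(f'({pattern})', text)
    segments = []
    for chunk, sep in zip(parts[0::2], parts[1::2] + ['']):
        merged = (chunk + sep).strip()
        if merged:
            segments.append(merged)
    return segments

print(split_segments("Hi, have you updated your NeMo? Yes. Sure!"))
# ['Hi, have you updated your NeMo?', 'Yes.', 'Sure!']
```

This reproduces the example from the note: separators stay attached to their segment instead of being stripped as in pre-2.5.0 behavior.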

nemo/collections/asr/metrics/bleu.py

Lines changed: 1 addition & 1 deletion

@@ -184,7 +184,7 @@ def update(
         tgt_lenths_cpu_tensor = targets_lengths.long().cpu()
         for idx, tgt_len in enumerate(tgt_lenths_cpu_tensor):
             target = targets_cpu_tensor[idx][:tgt_len].tolist()
-            reference = self.decoding.decode_tokens_to_str(target)
+            reference = self.decoding.decode_ids_to_str(target)
             tok = tokenizers[idx] if tokenizers else None  # `None` arg uses default tokenizer

             # TODO: the backend implementation of this has a lot of cpu to gpu operations. Should reimplement

nemo/collections/asr/metrics/wer.py

Lines changed: 1 addition & 1 deletion

@@ -324,7 +324,7 @@ def update(
         for ind in range(targets_cpu_tensor.shape[0]):
             tgt_len = tgt_lenths_cpu_tensor[ind].item()
             target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist()
-            reference = self.decoding.decode_tokens_to_str(target)
+            reference = self.decoding.decode_ids_to_str(target)
             references.append(reference)
         hypotheses = (
             self.decode(predictions, predictions_lengths, predictions_mask, input_ids)
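Both metric updates follow the same pattern: slice each padded target row to its true length, then decode the IDs to a reference string. A minimal sketch with a toy stand-in for the decoding object (the `ToyDecoding` class and its vocabulary are hypothetical; only the `decode_ids_to_str` name comes from the diff):

```python
class ToyDecoding:
    """Hypothetical stand-in for the model's decoding object."""
    vocab = {0: 'hello', 1: 'world', 2: 'again'}

    def decode_ids_to_str(self, ids):
        # The real method maps token IDs back to text via the tokenizer.
        return ' '.join(self.vocab[i] for i in ids)

decoding = ToyDecoding()
# Padded batch of target IDs plus the true length of each row,
# mirroring the loop in the WER/BLEU update() methods.
targets = [[0, 1, 2], [0, 1, 0]]
target_lengths = [3, 2]

references = []
for row, tgt_len in zip(targets, target_lengths):
    references.append(decoding.decode_ids_to_str(row[:tgt_len]))

print(references)  # ['hello world again', 'hello world']
```

Slicing with `[:tgt_len]` is what keeps padding tokens out of the reference text.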

nemo/collections/asr/models/aed_multitask_models.py

Lines changed: 100 additions & 18 deletions

@@ -18,6 +18,7 @@
 from dataclasses import dataclass, field
 from math import ceil
 from typing import Any, Dict, List, Optional, Union
+
 import numpy as np
 import torch
 from lightning.pytorch import Trainer
@@ -40,14 +41,18 @@
 from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig
 from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier
 from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
-from nemo.collections.asr.parts.utils.timestamp_utils import process_aed_timestamp_outputs
+from nemo.collections.asr.parts.utils.timestamp_utils import (
+    get_forced_aligned_timestamps_with_external_model,
+    process_aed_timestamp_outputs,
+)
 from nemo.collections.common import tokenizers
 from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config
 from nemo.collections.common.metrics import GlobalAverageLossMetric
 from nemo.collections.common.parts import transformer_weights_init
 from nemo.collections.common.parts.preprocessing.manifest import get_full_path
 from nemo.collections.common.prompts.formatter import PromptFormatter
 from nemo.core.classes.common import typecheck
+from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
 from nemo.core.neural_types import (
     AudioSignal,
     ChannelType,
@@ -59,6 +64,7 @@
     SpectrogramType,
 )
 from nemo.utils import logging, model_utils
+from nemo.utils.app_state import AppState

 __all__ = ['EncDecMultiTaskModel']

@@ -241,6 +247,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         # Setup encoder adapters (from ASRAdapterModelMixin)
         self.setup_adapters()

+        timestamps_asr_model = self.__restore_timestamps_asr_model()
+        # Using object.__setattr__ to bypass PyTorch's module registration
+        object.__setattr__(self, 'timestamps_asr_model', timestamps_asr_model)
+
     def change_decoding_strategy(self, decoding_cfg: DictConfig):
         """
         Changes decoding strategy used during Multi Task decoding process.
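The `object.__setattr__` call in `__init__` keeps the helper model out of the parent module's registered state, because `torch.nn.Module` intercepts normal attribute assignment. The mechanism can be shown without torch, using a stand-in class (`Registry` here is hypothetical, purely to illustrate the bypass):

```python
class Registry:
    """Minimal stand-in for torch.nn.Module's attribute interception."""

    def __init__(self):
        object.__setattr__(self, '_registered', {})

    def __setattr__(self, name, value):
        # Normal assignment is intercepted and the value is recorded
        # (nn.Module does something similar for parameters and submodules).
        self._registered[name] = value
        object.__setattr__(self, name, value)

m = Registry()
m.tracked = 'submodule'                     # intercepted: gets registered
object.__setattr__(m, 'hidden', 'helper')   # bypasses interception entirely

print('tracked' in m._registered)  # True
print('hidden' in m._registered)   # False
print(m.hidden)                    # helper
```

In the diff this keeps `timestamps_asr_model` from being saved and checkpointed as a child of the main model.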
@@ -518,16 +528,18 @@ def transcribe(
             as paths2audio_files
         """
         if timestamps is not None:
-            # TODO: Handle this key gracefully later
-
-            if timestamps is True:
-                timestamps = 'yes'
-            elif timestamps is False:
-                timestamps = 'no'
+            if self.timestamps_asr_model is None:
+                # TODO: Handle this key gracefully later
+                if timestamps is True:
+                    timestamps = 'yes'
+                elif timestamps is False:
+                    timestamps = 'no'
+                else:
+                    timestamps = str(timestamps)
+                assert timestamps in ('yes', 'no', 'timestamp', 'notimestamp', '1', '0')
+                prompt['timestamp'] = timestamps
             else:
-                timestamps = str(timestamps)
-                assert timestamps in ('yes', 'no', 'timestamp', 'notimestamp', '1', '0')
-                prompt['timestamp'] = timestamps
+                prompt['timestamp'] = 'no'

         if override_config is None:
             trcfg = MultiTaskTranscriptionConfig(
@@ -538,6 +550,7 @@ def transcribe(
                 augmentor=augmentor,
                 verbose=verbose,
                 prompt=prompt,
+                timestamps=timestamps,
             )
         else:
             if not isinstance(override_config, MultiTaskTranscriptionConfig):
@@ -546,6 +559,7 @@ def transcribe(
                     f"but got {type(override_config)}"
                 )
             trcfg = override_config
+            trcfg.timestamps = timestamps

         return super().transcribe(audio=audio, override_config=trcfg)
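The boolean/string normalization performed in `transcribe()` can be isolated as a small helper. This sketches the same if/elif/else logic from the diff; `normalize_timestamp_flag` is a hypothetical name, not an actual NeMo function:

```python
def normalize_timestamp_flag(timestamps):
    # Booleans map to the prompt strings 'yes'/'no'; anything else is
    # stringified and validated against the accepted prompt values.
    if timestamps is True:
        timestamps = 'yes'
    elif timestamps is False:
        timestamps = 'no'
    else:
        timestamps = str(timestamps)
    assert timestamps in ('yes', 'no', 'timestamp', 'notimestamp', '1', '0')
    return timestamps

print(normalize_timestamp_flag(True))   # yes
print(normalize_timestamp_flag(0))      # 0
```

Per the diff, this normalization now only runs when no external timestamps model is attached; otherwise the prompt is forced to `'no'` and alignment is delegated to the external model.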

@@ -856,6 +870,9 @@ def _transcribe_on_begin(self, audio, trcfg: MultiTaskTranscriptionConfig):
             trcfg._internal.primary_language = self.tokenizer.langs[0]
             logging.debug(f"Transcribing with default setting of {trcfg._internal.primary_language}.")

+        if trcfg.timestamps and self.timestamps_asr_model is not None:
+            self.timestamps_asr_model.to(trcfg._internal.device)
+
     def _transcribe_input_manifest_processing(
         self, audio_files: List[str], temp_dir: str, trcfg: MultiTaskTranscriptionConfig
     ) -> Dict[str, Any]:
@@ -955,6 +972,7 @@ def _transcribe_forward(
             encoder_states=enc_states,
             encoder_mask=enc_mask,
             decoder_input_ids=decoder_input_ids,
+            batch=batch,
         )

     def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionConfig) -> GenericTranscriptionType:
@@ -976,6 +994,7 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo
         enc_states = outputs.pop('encoder_states')
         enc_mask = outputs.pop('encoder_mask')
         decoder_input_ids = outputs.pop('decoder_input_ids')
+        batch = outputs.pop('batch')

         del log_probs, encoded_len

@@ -988,10 +1007,19 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo

         del enc_states, enc_mask, decoder_input_ids

-        hypotheses = process_aed_timestamp_outputs(
-            hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
-        )
-
+        if trcfg.timestamps and self.timestamps_asr_model is not None:
+            hypotheses = get_forced_aligned_timestamps_with_external_model(
+                audio=[audio.squeeze()[:audio_len] for audio, audio_len in zip(batch.audio, batch.audio_lens)],
+                batch_size=len(batch.audio),
+                external_ctc_model=self.timestamps_asr_model,
+                main_model_predictions=hypotheses,
+                timestamp_type=['word', 'segment'],
+                viterbi_device=trcfg._internal.device,
+            )
+        elif trcfg.timestamps:
+            hypotheses = process_aed_timestamp_outputs(
+                hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
+            )
         return hypotheses

     def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
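When no external model is attached, `process_aed_timestamp_outputs` receives the encoder subsampling factor and the preprocessor window stride; together these convert encoder frame indices into seconds. A back-of-the-envelope sketch (the helper name `frame_to_seconds` and the default values are assumptions, not NeMo API):

```python
def frame_to_seconds(frame_idx, subsampling_factor=8, window_stride=0.01):
    # One encoder frame spans `subsampling_factor` feature frames,
    # each `window_stride` seconds long, so time = frame * factor * stride.
    return frame_idx * subsampling_factor * window_stride

# e.g. encoder frame 125 with 8x subsampling and a 10 ms stride:
print(frame_to_seconds(125))  # 10.0
```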
@@ -1062,6 +1090,7 @@ def _may_be_make_dict_and_fix_paths(self, json_items, manifest_path, trcfg: Mult
         # This method is a legacy helper for Canary that checks whether prompt slot values were provided
         # in the input manifest and if not, it injects the defaults.
         out_json_items = []
+        timestamps_required = False
         for item in json_items:
             if isinstance(item, str):
                 # assume it is a path to audio file
@@ -1099,7 +1128,21 @@ def _may_be_make_dict_and_fix_paths(self, json_items, manifest_path, trcfg: Mult
                 if k not in entry:
                     # last-chance fallback injecting legacy Canary defaults if none were provided.
                     entry[k] = default_turn.get(k, dv)
+                if k == "timestamp":
+                    if (
+                        str(entry[k]).lower() not in ['notimestamp', "no", "false", "0"]
+                        and self.timestamps_asr_model is not None
+                    ):
+                        timestamps_required = True
+                        entry[k] = 'notimestamp'
             out_json_items.append(entry)
+
+        if timestamps_required:
+            trcfg.timestamps = True
+            logging.warning(
+                "Timestamps are enabled for at least one of the input items. "
+                "Setting timestamps to True for all the input items, as the current model is using external ASR model for alignment."
+            )
         return out_json_items

     @classmethod
@@ -1113,7 +1156,12 @@ def get_transcribe_config(cls) -> MultiTaskTranscriptionConfig:
         return MultiTaskTranscriptionConfig()

     def predict_step(
-        self, batch: PromptedAudioToTextMiniBatch, batch_idx=0, dataloader_idx=0, has_processed_signal=False
+        self,
+        batch: PromptedAudioToTextMiniBatch,
+        batch_idx=0,
+        dataloader_idx=0,
+        has_processed_signal=False,
+        timestamps=False,
     ):
         if has_processed_signal:
             processed_signal = batch.audio
@@ -1140,9 +1188,10 @@ def predict_step(
             return_hypotheses=False,
         )

-        hypotheses = process_aed_timestamp_outputs(
-            hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
-        )
+        if timestamps and self.timestamps_asr_model is None:
+            hypotheses = process_aed_timestamp_outputs(
+                hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
+            )

         if batch.cuts:
             return list(zip(batch.cuts, hypotheses))
@@ -1182,6 +1231,39 @@ def oomptimizer_schema(self) -> dict:
             ],
         }

+    def __restore_timestamps_asr_model(self):
+        """
+        This method is used to restore the external timestamp ASR model that will be used for forced alignment in `.transcribe()`.
+        The config and weights are expected to be in the main .nemo file and be named `timestamps_asr_model_config.yaml` and `timestamps_asr_model_weights.ckpt` respectively.
+        """
+        app_state = AppState()
+        model_restore_path = app_state.model_restore_path
+
+        if not model_restore_path:
+            return None
+
+        save_restore_connector = SaveRestoreConnector()
+
+        filter_fn = lambda name: "timestamps_asr_model" in name
+        members = save_restore_connector._filtered_tar_info(model_restore_path, filter_fn=filter_fn)
+
+        if not members:
+            return None
+
+        try:
+            save_restore_connector.model_config_yaml = "timestamps_asr_model_config.yaml"
+            save_restore_connector.model_weights_ckpt = "timestamps_asr_model_weights.ckpt"
+            external_timestamps_model = ASRModel.restore_from(
+                model_restore_path, save_restore_connector=save_restore_connector
+            )
+            external_timestamps_model.eval()
+        except Exception as e:
+            raise RuntimeError(
+                f"Error restoring external timestamps ASR model with timestamps_asr_model_config.yaml and timestamps_asr_model_weights.ckpt: {e}"
+            )
+
+        return external_timestamps_model
+

 def parse_multitask_prompt(prompt: dict | None) -> list[dict]:
     if prompt is None or not prompt:
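`__restore_timestamps_asr_model` first checks whether the `.nemo` archive (a tar file) contains any member whose name mentions the nested model. That filtering step can be illustrated with the standard library alone; the archive contents below are fabricated stubs, and `_filtered_tar_info` itself is NeMo-internal:

```python
import io
import tarfile

# A .nemo checkpoint is a tar archive; the nested timestamps model is
# detected by filtering member names (mirroring _filtered_tar_info's role).
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode='w') as tar:
    for name in ('model_config.yaml',
                 'timestamps_asr_model_config.yaml',
                 'timestamps_asr_model_weights.ckpt'):
        data = b'stub'  # placeholder bytes standing in for real artifacts
        info = tarfile.TarInfo(name)
        info.size = len(data)
        tar.addfile(info, io.BytesIO(data))

buf.seek(0)
filter_fn = lambda name: 'timestamps_asr_model' in name
with tarfile.open(fileobj=buf, mode='r') as tar:
    members = [m.name for m in tar.getmembers() if filter_fn(m.name)]

print(members)
# ['timestamps_asr_model_config.yaml', 'timestamps_asr_model_weights.ckpt']
```

If the filter finds no matching members, the method returns `None` and the model falls back to its built-in timestamp processing.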

nemo/collections/asr/models/k2_aligner_model.py

Lines changed: 24 additions & 17 deletions

@@ -81,7 +81,8 @@ def _init_ctc_alignment_specific(self, cfg: DictConfig):
             self.graph_decoder.split_batch_size = self.decode_batch_size
         else:
             self.graph_decoder = ViterbiDecoderWithGraph(
-                num_classes=self.blank_id, split_batch_size=self.decode_batch_size,
+                num_classes=self.blank_id,
+                split_batch_size=self.decode_batch_size,
             )
         # override decoder args if a config is provided
         decoder_module_cfg = cfg.get("decoder_module_cfg", None)
@@ -119,16 +120,18 @@ def _init_rnnt_alignment_specific(self, cfg: DictConfig):

         from nemo.collections.asr.parts.k2.utils import apply_rnnt_prune_ranges, get_uniform_rnnt_prune_ranges

-        self.prepare_pruned_outputs = lambda encoder_outputs, encoded_len, decoder_outputs, transcript_len: apply_rnnt_prune_ranges(
-            encoder_outputs,
-            decoder_outputs,
-            get_uniform_rnnt_prune_ranges(
-                encoded_len,
-                transcript_len,
-                self.predictor_window_size + 1,
-                self.predictor_step_size,
-                encoder_outputs.size(1),
-            ).to(device=encoder_outputs.device),
+        self.prepare_pruned_outputs = (
+            lambda encoder_outputs, encoded_len, decoder_outputs, transcript_len: apply_rnnt_prune_ranges(
+                encoder_outputs,
+                decoder_outputs,
+                get_uniform_rnnt_prune_ranges(
+                    encoded_len,
+                    transcript_len,
+                    self.predictor_window_size + 1,
+                    self.predictor_step_size,
+                    encoder_outputs.size(1),
+                ).to(device=encoder_outputs.device),
+            )
         )

         from nemo.collections.asr.parts.k2.classes import GraphModuleConfig
@@ -231,9 +234,9 @@ def _rnnt_joint_pruned(
     def _apply_prob_suppress(self, log_probs: torch.Tensor) -> torch.Tensor:
         """Multiplies probability of an element with index self.prob_suppress_index by self.prob_suppress_value times
         with stochasticity preservation of the log_probs tensor.
-
+
         Often used to suppress <blank> probability of the output of a CTC model.
-
+
         Example:
             For
             - log_probs = torch.log(torch.tensor([0.015, 0.085, 0.9]))
@@ -305,7 +308,7 @@ def _predict_impl_rnnt_argmax(
         # we have no token probabilities for the argmax rnnt setup
         token_prob = [1.0] * len(tokens)
         if self.word_output:
-            words = [w for w in self._model.decoding.decode_tokens_to_str(pred_ids).split(" ") if w != ""]
+            words = [w for w in self._model.decoding.decode_ids_to_str(pred_ids).split(" ") if w != ""]
             words, word_begin, word_len, word_prob = (
                 self._process_tokens_to_words(tokens, token_begin, token_len, token_prob, words)
                 if hasattr(self._model, "tokenizer")
@@ -411,7 +414,7 @@ def _process_char_with_space_to_words(
     def _results_to_ctmUnits(
         self, s_id: int, pred: torch.Tensor, prob: torch.Tensor
     ) -> Tuple[int, List['FrameCtmUnit']]:
-        """Transforms predictions with probabilities to a list of FrameCtmUnit objects,
+        """Transforms predictions with probabilities to a list of FrameCtmUnit objects,
         containing frame-level alignment information (label, start, duration, probability), for a given sample id.

         Alignment information can be either token-based (char, wordpiece, ...) or word-based.
@@ -440,7 +443,7 @@ def _results_to_ctmUnits(
                 for i, j in zip(non_blank_idx.tolist(), non_blank_idx[1:].tolist() + [len(pred)])
             ]
             if self.word_output:
-                words = wer_module.decode_tokens_to_str(pred_ids).split(" ")
+                words = wer_module.decode_ids_to_str(pred_ids).split(" ")
                 words, word_begin, word_len, word_prob = (
                     self._process_tokens_to_words(tokens, token_begin, token_len, token_prob, words)
                     if hasattr(self._model, "tokenizer")
@@ -539,7 +542,11 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0) -> List[Tuple[int, 'F

     @torch.no_grad()
     def transcribe(
-        self, manifest: List[str], batch_size: int = 4, num_workers: int = None, verbose: bool = True,
+        self,
+        manifest: List[str],
+        batch_size: int = 4,
+        num_workers: int = None,
+        verbose: bool = True,
     ) -> List['FrameCtmUnit']:
         """
         Does alignment. Use this method for debugging and prototyping.
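`_results_to_ctmUnits` turns frame-level predictions into (label, start, duration) units. A simplified sketch of that collapse for argmax CTC output — illustrative only; the real method also carries probabilities, token/word modes, and `FrameCtmUnit` objects:

```python
def frames_to_units(pred, blank_id=0):
    # Group consecutive equal frame labels; each run becomes a
    # (label, start_frame, duration) tuple. Blank runs are dropped.
    units = []
    start = 0
    for i in range(1, len(pred) + 1):
        if i == len(pred) or pred[i] != pred[start]:
            if pred[start] != blank_id:
                units.append((pred[start], start, i - start))
            start = i
    return units

print(frames_to_units([0, 0, 3, 3, 0, 5, 5, 5, 0]))
# [(3, 2, 2), (5, 5, 3)]
```

Durations in frames can then be scaled to seconds using the model's frame shift.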

nemo/collections/asr/parts/mixins/mixins.py

Lines changed: 6 additions & 6 deletions

@@ -16,7 +16,6 @@
 import os
 import shutil
 import tarfile
-import unicodedata
 from abc import ABC, abstractmethod
 from typing import List

@@ -29,6 +28,10 @@
 from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder
 from nemo.collections.asr.parts.utils import asr_module_utils
 from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
+from nemo.collections.asr.parts.utils.tokenizer_utils import (
+    extract_capitalized_tokens_from_vocab,
+    extract_punctuation_from_vocab,
+)
 from nemo.collections.common import tokenizers
 from nemo.utils import app_state, logging

@@ -482,11 +485,8 @@ def _extract_tokenizer_from_config(self, tokenizer_cfg: DictConfig, dir: str):
     def _derive_tokenizer_properties(self):
         vocab = self.tokenizer.tokenizer.get_vocab()

-        capitalized_tokens = {token.strip() for token in vocab if any(char.isupper() for char in token)}
-        self.tokenizer.supports_capitalization = bool(capitalized_tokens)
-
-        punctuation = {char for token in vocab for char in token if unicodedata.category(char).startswith('P')}
-        self.tokenizer.supported_punctuation = punctuation
+        self.tokenizer.supports_capitalization = bool(extract_capitalized_tokens_from_vocab(vocab))
+        self.tokenizer.supported_punctuation = extract_punctuation_from_vocab(vocab)


 class ASRModuleMixin(ASRAdapterModelMixin):
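The two helpers that replace the inline logic can be sketched directly from the code this diff removes; the actual implementations in `tokenizer_utils` may differ in details:

```python
import unicodedata

# Sketches based on the inline logic removed by this diff: capitalization
# support is any token containing an uppercase character, and supported
# punctuation is every Unicode punctuation character found in the vocab.
def extract_capitalized_tokens_from_vocab(vocab):
    return {token.strip() for token in vocab if any(ch.isupper() for ch in token)}

def extract_punctuation_from_vocab(vocab):
    return {ch for token in vocab for ch in token
            if unicodedata.category(ch).startswith('P')}

# Toy sentencepiece-style vocab ('▁' marks a word boundary)
vocab = {'▁Hello': 0, 'world': 1, '!': 2, '▁you': 3, "'s": 4}
print(sorted(extract_capitalized_tokens_from_vocab(vocab)))  # ['▁Hello']
print(sorted(extract_punctuation_from_vocab(vocab)))         # ['!', "'"]
```

Factoring these out lets other tokenizer-owning classes derive the same properties without duplicating the set comprehensions.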
