diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py
index 4c1e4f46fef8..f4bdc16f1ea1 100644
--- a/nemo/collections/asr/models/aed_multitask_models.py
+++ b/nemo/collections/asr/models/aed_multitask_models.py
@@ -137,7 +137,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         # Convert to Hydra 1.0 compatible DictConfig
         cfg = model_utils.convert_model_config_to_dict_config(cfg)
-        cfg = model_utils.maybe_update_config_version(cfg)
+        cfg = model_utils.maybe_update_config_version(cfg, make_copy=False)
 
         _config_check(cfg)
 
         self.prompt_format = cfg.prompt_format
@@ -1320,18 +1320,34 @@ def __restore_timestamps_asr_model(self):
         The config and weights are expected to be in the main .nemo file and be named
         `timestamps_asr_model_config.yaml` and `timestamps_asr_model_weights.ckpt` respectively.
         """
         app_state = AppState()
+        nemo_file_folder = app_state.nemo_file_folder  # Already-extracted temp directory
         model_restore_path = app_state.model_restore_path
         if not model_restore_path:
             return None
 
         save_restore_connector = SaveRestoreConnector()
+        save_restore_connector.model_config_yaml = os.path.join(nemo_file_folder, "timestamps_asr_model_config.yaml")
+        save_restore_connector.model_weights_ckpt = os.path.join(nemo_file_folder, "timestamps_asr_model_weights.ckpt")
 
-        filter_fn = lambda name: "timestamps_asr_model" in name
-        members = save_restore_connector._filtered_tar_info(model_restore_path, filter_fn=filter_fn)
+        # Check if the model_restore_path is already an extracted directory (which happens during restore_from)
+        # If so, use it directly to avoid double extraction
+        if app_state.nemo_file_folder and os.path.isdir(app_state.nemo_file_folder):
+            # Verify that the timestamp model components exist in the extracted folder
+            config_exists = os.path.exists(save_restore_connector.model_config_yaml)
+            weights_exists = os.path.exists(save_restore_connector.model_weights_ckpt)
 
-        if not members:
-            return None
+            if not (config_exists and weights_exists):
+                return None
+
+            save_restore_connector.model_extracted_dir = app_state.nemo_file_folder
+
+        else:
+            filter_fn = lambda name: "timestamps_asr_model" in name
+            members = save_restore_connector._filtered_tar_info(model_restore_path, filter_fn=filter_fn)
+
+            if not members:
+                return None
 
         try:
             save_restore_connector.model_config_yaml = "timestamps_asr_model_config.yaml"
@@ -1340,6 +1356,7 @@ def __restore_timestamps_asr_model(self):
                 model_restore_path, save_restore_connector=save_restore_connector
             )
             external_timestamps_model.eval()
+
         except Exception as e:
             raise RuntimeError(
                 f"Error restoring external timestamps ASR model with timestamps_asr_model_config.yaml and timestamps_asr_model_weights.ckpt: {e}"
diff --git a/nemo/collections/asr/models/ctc_bpe_models.py b/nemo/collections/asr/models/ctc_bpe_models.py
index 488f5e621df1..7c611ba57709 100644
--- a/nemo/collections/asr/models/ctc_bpe_models.py
+++ b/nemo/collections/asr/models/ctc_bpe_models.py
@@ -42,7 +42,7 @@ class EncDecCTCModelBPE(EncDecCTCModel, ASRBPEMixin):
    def __init__(self, cfg: DictConfig, trainer=None):
         # Convert to Hydra 1.0 compatible DictConfig
         cfg = model_utils.convert_model_config_to_dict_config(cfg)
-        cfg = model_utils.maybe_update_config_version(cfg)
+        cfg = model_utils.maybe_update_config_version(cfg, make_copy=False)
 
         if 'tokenizer' not in cfg:
             raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !")
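Note on the `make_copy=False` pattern repeated in the constructors above: `convert_model_config_to_dict_config` already returns a fresh `DictConfig` (it round-trips the config through a plain container), so the follow-up version upgrade can safely mutate it in place instead of paying for a second deepcopy. A minimal sketch of the idea, using only OmegaConf; the NeMo helper names are as they appear in this diff:

    # Sketch: normalize once (which copies), then upgrade in place (no second copy).
    from omegaconf import DictConfig, OmegaConf

    def normalize(cfg: DictConfig) -> DictConfig:
        # Round-trip through a plain container: resolves interpolations
        # and yields a config detached from the caller's object.
        return OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))

    cfg = OmegaConf.create({"optim": {"name": "adamw", "lr": 1e-3}})
    fresh = normalize(cfg)
    assert fresh is not cfg  # safe to mutate `fresh` in place afterwards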
diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
index c0506dc880d1..0b4158797101 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -43,7 +43,7 @@ class EncDecHybridRNNTCTCBPEModel(EncDecHybridRNNTCTCModel, ASRBPEMixin):
     def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         # Convert to Hydra 1.0 compatible DictConfig
         cfg = model_utils.convert_model_config_to_dict_config(cfg)
-        cfg = model_utils.maybe_update_config_version(cfg)
+        cfg = model_utils.maybe_update_config_version(cfg, make_copy=False)
 
         # Tokenizer is necessary for this model
         if 'tokenizer' not in cfg:
diff --git a/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py b/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py
index 163dde5ecec1..f91d205f0760 100644
--- a/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py
+++ b/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py
@@ -79,7 +79,6 @@ def run_confidence_benchmark(
     draw_plot = plot_dir is not None
     if isinstance(plot_dir, str):
         plot_dir = Path(plot_dir)
-    is_rnnt = isinstance(model, EncDecRNNTModel)
 
     # transcribe audio
     with torch.amp.autocast(model.device.type, enabled=use_amp):
@@ -87,8 +86,6 @@ def run_confidence_benchmark(
         transcriptions = model.transcribe(
             audio=filepaths, batch_size=batch_size, return_hypotheses=True, num_workers=num_workers
         )
-    if is_rnnt:
-        transcriptions = transcriptions[0]
 
     levels = []
     if target_level != "word":
diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py
index 0829008409df..fab70dd1095e 100644
--- a/nemo/core/classes/common.py
+++ b/nemo/core/classes/common.py
@@ -44,7 +44,7 @@
 from nemo.core.config.templates.model_card import NEMO_DEFAULT_MODEL_CARD_TEMPLATE
 from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
 from nemo.core.neural_types import NeuralType, NeuralTypeComparisonResult
-from nemo.utils import logging
+from nemo.utils import logging, model_utils
 from nemo.utils.cloud import maybe_download_from_cloud
 from nemo.utils.data_utils import resolve_cache_dir
 from nemo.utils.model_utils import import_class_by_path, maybe_update_config_version
@@ -590,11 +590,9 @@ def from_config_dict(cls, config: 'DictConfig', trainer: Optional['Trainer'] = N
         """Instantiates object using DictConfig-based configuration"""
         # Resolve the config dict
         if isinstance(config, DictConfig):
-            config = OmegaConf.to_container(config, resolve=True)
-            config = OmegaConf.create(config)
-            OmegaConf.set_struct(config, True)
+            config = model_utils.convert_model_config_to_dict_config(config)
 
-        config = maybe_update_config_version(config)
+        config = maybe_update_config_version(config, make_copy=False)
 
         # Hydra 0.x API
         if ('cls' in config or 'target' in config) and 'params' in config:
@@ -654,12 +652,8 @@ def to_config_dict(self) -> 'DictConfig':
         """Returns object's configuration to config dictionary"""
         if hasattr(self, '_cfg') and self._cfg is not None:
             # Resolve the config dict
-            if isinstance(self._cfg, DictConfig):
-                config = OmegaConf.to_container(self._cfg, resolve=True)
-                config = OmegaConf.create(config)
-                OmegaConf.set_struct(config, True)
-
-            config = maybe_update_config_version(config)
+            config = model_utils.convert_model_config_to_dict_config(self._cfg)
+            config = maybe_update_config_version(config, make_copy=False)
 
             self._cfg = config
 
@@ -747,7 +741,7 @@ def to_config_file(self, path2yaml_file: str):
         Returns:
         """
         if hasattr(self, '_cfg'):
-            self._cfg = maybe_update_config_version(self._cfg)
+            self._cfg = maybe_update_config_version(self._cfg, make_copy=False)
             with open(path2yaml_file, 'w', encoding='utf-8') as fout:
                 OmegaConf.save(config=self._cfg, f=fout, resolve=True)
         else:
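The `from_config_dict`/`to_config_dict` refactor above drops the explicit `OmegaConf.set_struct(config, True)` call; per the `model_utils.py` hunk later in this diff, `maybe_update_config_version` re-enables struct mode before returning, so the net behavior should be unchanged. A rough sketch of that in-place upgrade path, with `_convert_config` elided as a placeholder:

    # Rough sketch of maybe_update_config_version(cfg, make_copy=False),
    # assuming the body shown in the model_utils.py hunk below.
    from omegaconf import OmegaConf

    def upgrade_in_place(cfg):
        OmegaConf.set_struct(cfg, False)  # allow key rewrites
        # ... _convert_config(cfg) would migrate Hydra 0.x keys here ...
        OmegaConf.set_struct(cfg, True)   # struct mode restored for callers
        return cfg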
diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py
index 89aa163bf062..3ede355392a9 100644
--- a/nemo/core/classes/modelPT.py
+++ b/nemo/core/classes/modelPT.py
@@ -118,7 +118,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         cfg = model_utils.convert_model_config_to_dict_config(cfg)
 
         # Convert config to support Hydra 1.0+ instantiation
-        cfg = model_utils.maybe_update_config_version(cfg)
+        cfg = model_utils.maybe_update_config_version(cfg, make_copy=False)
 
         if 'model' in cfg:
             raise ValueError(
diff --git a/nemo/core/connectors/save_restore_connector.py b/nemo/core/connectors/save_restore_connector.py
index 2ac3922d5fef..8a09c3d0875d 100644
--- a/nemo/core/connectors/save_restore_connector.py
+++ b/nemo/core/connectors/save_restore_connector.py
@@ -20,7 +20,7 @@
 import tempfile
 import time
 import uuid
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from typing import Callable, Generator, Optional, Set, Union
 
 import torch
@@ -136,19 +136,19 @@ def load_config_and_state_dict(
             map_location = torch.device('cpu')
 
         app_state = AppState()
-        with tempfile.TemporaryDirectory() as tmpdir:
-            try:
-                # Check if self.model_extracted_dir is set, and is a valid path
-                if self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir):
-                    # Log that NeMo will use the provided `model_extracted_dir`
-                    logging.info(
-                        f"Restoration will occur within pre-extracted directory : " f"`{self.model_extracted_dir}`."
-                    )
-                    # Override `tmpdir` above with the pre-extracted `model_extracted_dir`
-                    tmpdir = self.model_extracted_dir
+        # Determine if we should use a pre-extracted directory
+        use_extracted_dir = self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir)
 
-                else:
+        if use_extracted_dir:
+            logging.info(f"Restoration will occur within pre-extracted directory : " f"`{self.model_extracted_dir}`.")
+
+        # Use nullcontext if we have an extracted dir, otherwise create a temp directory
+        dir_context = nullcontext(self.model_extracted_dir) if use_extracted_dir else tempfile.TemporaryDirectory()
+
+        with dir_context as tmpdir:
+            try:
+                if not use_extracted_dir:
                     # Extract the nemo file into the temporary directory
                     filter_fn = None
                     if return_config:
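The rewrite above replaces the unconditionally created `TemporaryDirectory` with a `nullcontext` whenever a pre-extracted directory is available, so no throwaway temp dir is created (and cleaned up) on that path. The idiom in isolation, as a runnable sketch:

    # Sketch of the nullcontext/TemporaryDirectory idiom: one `with` block
    # covers both the pre-extracted and the fresh-extraction paths.
    import os
    import tempfile
    from contextlib import nullcontext

    def workdir(extracted_dir=None):
        use_extracted = extracted_dir is not None and os.path.isdir(extracted_dir)
        # nullcontext yields the given path unchanged and cleans up nothing;
        # TemporaryDirectory yields a new path and removes it on exit.
        return nullcontext(extracted_dir) if use_extracted else tempfile.TemporaryDirectory()

    with workdir() as tmpdir:
        print("extracting into", tmpdir)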
diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py
index 93340eb0f2ba..ebff99d6c160 100644
--- a/nemo/utils/model_utils.py
+++ b/nemo/utils/model_utils.py
@@ -490,6 +490,7 @@ def convert_model_config_to_dict_config(cfg: Union['DictConfig', 'NemoConfig'])
 
     config = OmegaConf.to_container(cfg, resolve=True)
     config = OmegaConf.create(config)
+
     return config
 
 
@@ -508,13 +509,13 @@ def _convert_config(cfg: 'OmegaConf'):
     # Recursion.
     try:
         for _, sub_cfg in cfg.items():
-            if isinstance(sub_cfg, DictConfig):
+            if isinstance(sub_cfg, (dict, DictConfig)):
                 _convert_config(sub_cfg)
     except omegaconf_errors.OmegaConfBaseException as e:
         logging.warning(f"Skipped conversion for config/subconfig:\n{cfg}\n Reason: {e}.")
 
 
-def maybe_update_config_version(cfg: 'DictConfig'):
+def maybe_update_config_version(cfg: 'DictConfig', make_copy: bool = True):
     """
     Recursively convert Hydra 0.x configs to Hydra 1.x configs.
@@ -525,6 +526,7 @@ def maybe_update_config_version(cfg: 'DictConfig'):
 
     Args:
         cfg: Any Hydra compatible DictConfig
+        make_copy: bool indicating whether the config should be copied before updating
 
     Returns:
         An updated DictConfig that conforms to Hydra 1.x format.
@@ -537,14 +539,15 @@ def maybe_update_config_version(cfg: 'DictConfig'):
         # Cannot be cast to DictConfig, skip updating.
         return cfg
 
-    # Make a copy of model config.
-    cfg = copy.deepcopy(cfg)
+    # Make a copy if requested
+    if make_copy:
+        cfg = copy.deepcopy(cfg)
+
     OmegaConf.set_struct(cfg, False)
 
-    # Convert config.
+    # Convert config
     _convert_config(cfg)
 
-    # Update model config.
     OmegaConf.set_struct(cfg, True)
 
     return cfg
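Two behavioral notes on the `model_utils.py` hunks: `maybe_update_config_version` now deep-copies only when `make_copy=True` (the default preserves the old contract for existing callers), and `_convert_config` recurses into plain `dict` values as well as `DictConfig` nodes, presumably because subtrees can surface as plain dicts after a `to_container` round-trip. A small sketch of the widened recursion, with the per-node key rewriting elided:

    # Sketch of the widened _convert_config traversal; only the recursion
    # is shown, not the Hydra 0.x -> 1.x key rewriting.
    from omegaconf import DictConfig, OmegaConf

    def walk(cfg):
        for _, sub_cfg in cfg.items():
            if isinstance(sub_cfg, (dict, DictConfig)):  # now also plain dicts
                walk(sub_cfg)

    walk({"model": {"optim": {"name": "adamw"}}})                 # plain-dict tree
    walk(OmegaConf.create({"model": {"optim": {"sched": {}}}}))   # DictConfig tree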
diff --git a/tests/collections/asr/confidence/test_asr_confidence.py b/tests/collections/asr/confidence/test_asr_confidence.py
index 89beb61f50bf..b46119fa1d6a 100644
--- a/tests/collections/asr/confidence/test_asr_confidence.py
+++ b/tests/collections/asr/confidence/test_asr_confidence.py
@@ -19,43 +19,27 @@
 import numpy as np
 import pytest
-from lightning.pytorch import Trainer
 from omegaconf import OmegaConf
 
-from nemo.collections.asr.models import ASRModel, EncDecCTCModelBPE, EncDecRNNTBPEModel
+from nemo.collections.asr.models import ASRModel, EncDecMultiTaskModel
 from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig
 from nemo.collections.asr.parts.submodules.ctc_greedy_decoding import GreedyCTCInferConfig
+from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecodingConfig
+from nemo.collections.asr.parts.submodules.multitask_greedy_decoding import AEDGreedyInferConfig
 from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
 from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig
 from nemo.collections.asr.parts.utils.asr_confidence_benchmarking_utils import run_confidence_benchmark
 from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig
 
 # both models recognize the test data without errors, thus every metric except ece return default values
-ECE_VALUES = {("token", "ctc"): 0.87, ("token", "rnnt"): 0.82, ("word", "ctc"): 0.91, ("word", "rnnt"): 0.88}
+# ECE values for fast conformer models (stt_en_fastconformer_ctc_large and stt_en_fastconformer_transducer_large)
+ECE_VALUES = {("token", "ctc"): 0.86, ("token", "rnnt"): 0.75, ("word", "ctc"): 0.89, ("word", "rnnt"): 0.80}
 
 TOL_DEGREE = 2
-TOL = 1 / math.pow(10, TOL_DEGREE)
+TOL = 2 / math.pow(10, TOL_DEGREE)
 
 
 @pytest.fixture(scope="module")
-def conformer_ctc_bpe_model():
-    model = EncDecCTCModelBPE.from_pretrained(model_name="stt_en_conformer_ctc_small")
-    model.set_trainer(Trainer(devices=1, accelerator="cpu"))
-    model = model.eval()
-    return model
-
-
-@pytest.fixture(scope="module")
-def conformer_rnnt_bpe_model():
-    model = EncDecRNNTBPEModel.from_pretrained(model_name="stt_en_conformer_transducer_small")
-    model.set_trainer(Trainer(devices=1, accelerator="cpu"))
-    model = model.eval()
-    return model
-
-
-@pytest.mark.with_downloads
-@pytest.fixture(scope="module")
-# @pytest.fixture
 def audio_and_texts(test_data_dir):
     # get filenames and reference texts from manifest
     filepaths = []
@@ -72,24 +56,28 @@ def audio_and_texts(test_data_dir):
 
 
 class TestASRConfidenceBenchmark:
-    @pytest.mark.pleasefixme
     @pytest.mark.integration
     @pytest.mark.with_downloads
     @pytest.mark.parametrize('model_name', ("ctc", "rnnt"))
     @pytest.mark.parametrize('target_level', ("token", "word"))
     def test_run_confidence_benchmark(
-        self, model_name, target_level, audio_and_texts, conformer_ctc_bpe_model, conformer_rnnt_bpe_model
+        self, model_name, target_level, audio_and_texts, fast_conformer_ctc_model, fast_conformer_transducer_model
     ):
-        model = conformer_ctc_bpe_model if model_name == "ctc" else conformer_rnnt_bpe_model
+        model = fast_conformer_ctc_model if model_name == "ctc" else fast_conformer_transducer_model
         assert isinstance(model, ASRModel)
 
         filepaths, reference_texts = audio_and_texts
         confidence_cfg = (
-            ConfidenceConfig(preserve_token_confidence=True)
+            ConfidenceConfig(preserve_frame_confidence=True, preserve_token_confidence=True)
             if target_level == "token"
-            else ConfidenceConfig(preserve_word_confidence=True)
+            else ConfidenceConfig(preserve_frame_confidence=True, preserve_word_confidence=True)
         )
         model.change_decoding_strategy(
-            RNNTDecodingConfig(fused_batch_size=-1, strategy="greedy_batch", confidence_cfg=confidence_cfg)
+            RNNTDecodingConfig(
+                fused_batch_size=-1,
+                strategy="greedy_batch",
+                confidence_cfg=confidence_cfg,
+                greedy=GreedyBatchedRNNTInferConfig(loop_labels=False),
+            )
             if model_name == "rnnt"
             else CTCDecodingConfig(confidence_cfg=confidence_cfg)
         )
@@ -104,13 +92,12 @@ def test_run_confidence_benchmark(
             atol=TOL,
         )
 
-    @pytest.mark.pleasefixme
     @pytest.mark.integration
     @pytest.mark.with_downloads
     @pytest.mark.parametrize('model_name', ("ctc", "rnnt"))
-    def test_deprecated_config_args(self, model_name, conformer_ctc_bpe_model, conformer_rnnt_bpe_model):
+    def test_deprecated_config_args(self, model_name, fast_conformer_ctc_model, fast_conformer_transducer_model):
         assert ConfidenceConfig().method_cfg.alpha == 0.33, "default `alpha` is supposed to be 0.33"
-        model = conformer_ctc_bpe_model if model_name == "ctc" else conformer_rnnt_bpe_model
+        model = fast_conformer_ctc_model if model_name == "ctc" else fast_conformer_transducer_model
         assert isinstance(model, ASRModel)
 
         conf = OmegaConf.create({"temperature": 0.5})
@@ -133,3 +120,50 @@ def test_deprecated_config_args(self, model_name, conformer_ctc_bpe_model, confo
             else CTCDecodingConfig(greedy=GreedyCTCInferConfig(preserve_frame_confidence=True, **test_args_greedy))
         )
         assert model.cfg.decoding.greedy.confidence_method_cfg.alpha == 0.5
+
+    @pytest.mark.unit
+    def test_aed_multitask_model_confidence(self, canary_1b_v2, test_data_dir):
+        """Test token and word confidence for AED multitask models (Canary)."""
+        model = canary_1b_v2
+        assert isinstance(model, EncDecMultiTaskModel)
+
+        audio_file = Path(test_data_dir) / "asr" / "train" / "an4" / "wav" / "an46-mmap-b.wav"
+
+        # Configure decoding with confidence
+        decode_cfg = MultiTaskDecodingConfig(
+            strategy='greedy',
+            greedy=AEDGreedyInferConfig(preserve_token_confidence=True),
+            confidence_cfg=ConfidenceConfig(preserve_token_confidence=True, preserve_word_confidence=True),
+        )
+        model.change_decoding_strategy(decode_cfg)
+
+        hypotheses = model.transcribe(
+            audio=str(audio_file),
+            batch_size=1,
+            return_hypotheses=True,
+        )
+
+        assert len(hypotheses) == 1
+        hyp = hypotheses[0]
+
+        # Verify text is present
+        assert isinstance(hyp.text, str)
+        assert len(hyp.text) > 0
+
+        # Verify y_sequence is present
+        assert hyp.y_sequence is not None
+        assert len(hyp.y_sequence) > 0
+
+        # Verify token confidence is present and has correct length
+        assert hyp.token_confidence is not None
+        assert len(hyp.token_confidence) == len(hyp.y_sequence)
+
+        # Verify word confidence is present
+        assert hyp.word_confidence is not None
+        assert len(hyp.word_confidence) > 0
+
+        # Verify confidence values are in valid range [0, 1]
+        for conf in hyp.token_confidence:
+            assert 0.0 <= conf <= 1.0, f"Token confidence {conf} not in valid range [0, 1]"
+        for conf in hyp.word_confidence:
+            assert 0.0 <= conf <= 1.0, f"Word confidence {conf} not in valid range [0, 1]"
diff --git a/tests/collections/asr/mixins/test_transcription.py b/tests/collections/asr/mixins/test_transcription.py
index 036c558f5087..255f0304149f 100644
--- a/tests/collections/asr/mixins/test_transcription.py
+++ b/tests/collections/asr/mixins/test_transcription.py
@@ -79,7 +79,6 @@ def __len__(self):
         return len(self.audio_tensors)
 
 
-@pytest.mark.with_downloads()
 @pytest.fixture()
 def audio_files(test_data_dir):
     """
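On the loosened test tolerance: with `TOL_DEGREE = 2`, the old `TOL = 1 / math.pow(10, TOL_DEGREE)` gave an absolute tolerance of 0.01, and the new `2 / math.pow(10, TOL_DEGREE)` doubles it to 0.02 for the re-pinned fast-conformer ECE values. A quick check of the comparison the test performs; the measured value below is illustrative, not from the test data:

    import math
    import numpy as np

    TOL_DEGREE = 2
    TOL = 2 / math.pow(10, TOL_DEGREE)
    assert TOL == 0.02  # was 1 / 10**2 == 0.01

    # np.allclose(..., atol=TOL) is the pattern visible in the test's context
    # lines; a measured ECE of 0.765 passes against the pinned 0.75 at the
    # new tolerance, but would have failed at the old 0.01.
    assert np.allclose([0.765], [0.75], atol=TOL)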