Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

# Convert to Hydra 1.0 compatible DictConfig
cfg = model_utils.convert_model_config_to_dict_config(cfg)
cfg = model_utils.maybe_update_config_version(cfg)
cfg = model_utils.maybe_update_config_version(cfg, make_copy=False)
_config_check(cfg)

self.prompt_format = cfg.prompt_format
Expand Down Expand Up @@ -1320,18 +1320,34 @@ def __restore_timestamps_asr_model(self):
The config and weights are expected to be in the main .nemo file and be named `timestamps_asr_model_config.yaml` and `timestamps_asr_model_weights.ckpt` respectively.
"""
app_state = AppState()
nemo_file_folder = app_state.nemo_file_folder # Already-extracted temp directory
model_restore_path = app_state.model_restore_path

if not model_restore_path:
return None

save_restore_connector = SaveRestoreConnector()
save_restore_connector.model_config_yaml = os.path.join(nemo_file_folder, "timestamps_asr_model_config.yaml")
save_restore_connector.model_weights_ckpt = os.path.join(nemo_file_folder, "timestamps_asr_model_weights.ckpt")

filter_fn = lambda name: "timestamps_asr_model" in name
members = save_restore_connector._filtered_tar_info(model_restore_path, filter_fn=filter_fn)
# Check if the model_restore_path is already an extracted directory (which happens during restore_from)
# If so, use it directly to avoid double extraction
if app_state.nemo_file_folder and os.path.isdir(app_state.nemo_file_folder):
# Verify that the timestamp model components exist in the extracted folder
config_exists = os.path.exists(save_restore_connector.model_config_yaml)
weights_exists = os.path.exists(save_restore_connector.model_weights_ckpt)

if not members:
return None
if not (config_exists and weights_exists):
return None

save_restore_connector.model_extracted_dir = app_state.nemo_file_folder

else:
filter_fn = lambda name: "timestamps_asr_model" in name
members = save_restore_connector._filtered_tar_info(model_restore_path, filter_fn=filter_fn)

if not members:
return None

try:
save_restore_connector.model_config_yaml = "timestamps_asr_model_config.yaml"
Expand All @@ -1340,6 +1356,7 @@ def __restore_timestamps_asr_model(self):
model_restore_path, save_restore_connector=save_restore_connector
)
external_timestamps_model.eval()

except Exception as e:
raise RuntimeError(
f"Error restoring external timestamps ASR model with timestamps_asr_model_config.yaml and timestamps_asr_model_weights.ckpt: {e}"
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class EncDecCTCModelBPE(EncDecCTCModel, ASRBPEMixin):
def __init__(self, cfg: DictConfig, trainer=None):
# Convert to Hydra 1.0 compatible DictConfig
cfg = model_utils.convert_model_config_to_dict_config(cfg)
cfg = model_utils.maybe_update_config_version(cfg)
cfg = model_utils.maybe_update_config_version(cfg, make_copy=False)

if 'tokenizer' not in cfg:
raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !")
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class EncDecHybridRNNTCTCBPEModel(EncDecHybridRNNTCTCModel, ASRBPEMixin):
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# Convert to Hydra 1.0 compatible DictConfig
cfg = model_utils.convert_model_config_to_dict_config(cfg)
cfg = model_utils.maybe_update_config_version(cfg)
cfg = model_utils.maybe_update_config_version(cfg, make_copy=False)

# Tokenizer is necessary for this model
if 'tokenizer' not in cfg:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,13 @@ def run_confidence_benchmark(
draw_plot = plot_dir is not None
if isinstance(plot_dir, str):
plot_dir = Path(plot_dir)
is_rnnt = isinstance(model, EncDecRNNTModel)

# transcribe audio
with torch.amp.autocast(model.device.type, enabled=use_amp):
with torch.no_grad():
transcriptions = model.transcribe(
audio=filepaths, batch_size=batch_size, return_hypotheses=True, num_workers=num_workers
)
if is_rnnt:
transcriptions = transcriptions[0]

levels = []
if target_level != "word":
Expand Down
18 changes: 6 additions & 12 deletions nemo/core/classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from nemo.core.config.templates.model_card import NEMO_DEFAULT_MODEL_CARD_TEMPLATE
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.core.neural_types import NeuralType, NeuralTypeComparisonResult
from nemo.utils import logging
from nemo.utils import logging, model_utils
from nemo.utils.cloud import maybe_download_from_cloud
from nemo.utils.data_utils import resolve_cache_dir
from nemo.utils.model_utils import import_class_by_path, maybe_update_config_version
Expand Down Expand Up @@ -590,11 +590,9 @@ def from_config_dict(cls, config: 'DictConfig', trainer: Optional['Trainer'] = N
"""Instantiates object using DictConfig-based configuration"""
# Resolve the config dict
if isinstance(config, DictConfig):
config = OmegaConf.to_container(config, resolve=True)
config = OmegaConf.create(config)
OmegaConf.set_struct(config, True)
config = model_utils.convert_model_config_to_dict_config(config)

config = maybe_update_config_version(config)
config = maybe_update_config_version(config, make_copy=False)

# Hydra 0.x API
if ('cls' in config or 'target' in config) and 'params' in config:
Expand Down Expand Up @@ -654,12 +652,8 @@ def to_config_dict(self) -> 'DictConfig':
"""Returns object's configuration to config dictionary"""
if hasattr(self, '_cfg') and self._cfg is not None:
# Resolve the config dict
if isinstance(self._cfg, DictConfig):
config = OmegaConf.to_container(self._cfg, resolve=True)
config = OmegaConf.create(config)
OmegaConf.set_struct(config, True)

config = maybe_update_config_version(config)
config = model_utils.convert_model_config_to_dict_config(self._cfg)
config = maybe_update_config_version(config, make_copy=False)

self._cfg = config

Expand Down Expand Up @@ -747,7 +741,7 @@ def to_config_file(self, path2yaml_file: str):
Returns:
"""
if hasattr(self, '_cfg'):
self._cfg = maybe_update_config_version(self._cfg)
self._cfg = maybe_update_config_version(self._cfg, make_copy=False)
with open(path2yaml_file, 'w', encoding='utf-8') as fout:
OmegaConf.save(config=self._cfg, f=fout, resolve=True)
else:
Expand Down
2 changes: 1 addition & 1 deletion nemo/core/classes/modelPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
cfg = model_utils.convert_model_config_to_dict_config(cfg)

# Convert config to support Hydra 1.0+ instantiation
cfg = model_utils.maybe_update_config_version(cfg)
cfg = model_utils.maybe_update_config_version(cfg, make_copy=False)

if 'model' in cfg:
raise ValueError(
Expand Down
24 changes: 12 additions & 12 deletions nemo/core/connectors/save_restore_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import tempfile
import time
import uuid
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from typing import Callable, Generator, Optional, Set, Union

import torch
Expand Down Expand Up @@ -136,19 +136,19 @@ def load_config_and_state_dict(
map_location = torch.device('cpu')

app_state = AppState()
with tempfile.TemporaryDirectory() as tmpdir:
try:
# Check if self.model_extracted_dir is set, and is a valid path
if self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir):
# Log that NeMo will use the provided `model_extracted_dir`
logging.info(
f"Restoration will occur within pre-extracted directory : " f"`{self.model_extracted_dir}`."
)

# Override `tmpdir` above with the pre-extracted `model_extracted_dir`
tmpdir = self.model_extracted_dir
# Determine if we should use a pre-extracted directory
use_extracted_dir = self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir)

else:
if use_extracted_dir:
logging.info(f"Restoration will occur within pre-extracted directory : " f"`{self.model_extracted_dir}`.")

# Use nullcontext if we have an extracted dir, otherwise create a temp directory
dir_context = nullcontext(self.model_extracted_dir) if use_extracted_dir else tempfile.TemporaryDirectory()

with dir_context as tmpdir:
try:
if not use_extracted_dir:
# Extract the nemo file into the temporary directory
filter_fn = None
if return_config:
Expand Down
15 changes: 9 additions & 6 deletions nemo/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ def convert_model_config_to_dict_config(cfg: Union['DictConfig', 'NemoConfig'])

config = OmegaConf.to_container(cfg, resolve=True)
config = OmegaConf.create(config)

return config


Expand All @@ -508,13 +509,13 @@ def _convert_config(cfg: 'OmegaConf'):
# Recursion.
try:
for _, sub_cfg in cfg.items():
if isinstance(sub_cfg, DictConfig):
if isinstance(sub_cfg, (dict, DictConfig)):
_convert_config(sub_cfg)
except omegaconf_errors.OmegaConfBaseException as e:
logging.warning(f"Skipped conversion for config/subconfig:\n{cfg}\n Reason: {e}.")


def maybe_update_config_version(cfg: 'DictConfig'):
def maybe_update_config_version(cfg: 'DictConfig', make_copy: bool = True):
"""
Recursively convert Hydra 0.x configs to Hydra 1.x configs.

Expand All @@ -525,6 +526,7 @@ def maybe_update_config_version(cfg: 'DictConfig'):

Args:
cfg: Any Hydra compatible DictConfig
make_copy: bool to indicating if the config should be copied before updating

Returns:
An updated DictConfig that conforms to Hydra 1.x format.
Expand All @@ -537,14 +539,15 @@ def maybe_update_config_version(cfg: 'DictConfig'):
# Cannot be cast to DictConfig, skip updating.
return cfg

# Make a copy of model config.
cfg = copy.deepcopy(cfg)
# Make a copy if requested
if make_copy:
cfg = copy.deepcopy(cfg)

OmegaConf.set_struct(cfg, False)

# Convert config.
# Convert config
_convert_config(cfg)

# Update model config.
OmegaConf.set_struct(cfg, True)

return cfg
Expand Down
96 changes: 65 additions & 31 deletions tests/collections/asr/confidence/test_asr_confidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,43 +19,27 @@

import numpy as np
import pytest
from lightning.pytorch import Trainer
from omegaconf import OmegaConf

from nemo.collections.asr.models import ASRModel, EncDecCTCModelBPE, EncDecRNNTBPEModel
from nemo.collections.asr.models import ASRModel, EncDecMultiTaskModel
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig
from nemo.collections.asr.parts.submodules.ctc_greedy_decoding import GreedyCTCInferConfig
from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecodingConfig
from nemo.collections.asr.parts.submodules.multitask_greedy_decoding import AEDGreedyInferConfig
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig
from nemo.collections.asr.parts.utils.asr_confidence_benchmarking_utils import run_confidence_benchmark
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig

# both models recognize the test data without errors, thus every metric except ece return default values
ECE_VALUES = {("token", "ctc"): 0.87, ("token", "rnnt"): 0.82, ("word", "ctc"): 0.91, ("word", "rnnt"): 0.88}
# ECE values for fast conformer models (stt_en_fastconformer_ctc_large and stt_en_fastconformer_transducer_large)
ECE_VALUES = {("token", "ctc"): 0.86, ("token", "rnnt"): 0.75, ("word", "ctc"): 0.89, ("word", "rnnt"): 0.80}

TOL_DEGREE = 2
TOL = 1 / math.pow(10, TOL_DEGREE)
TOL = 2 / math.pow(10, TOL_DEGREE)


@pytest.fixture(scope="module")
def conformer_ctc_bpe_model():
model = EncDecCTCModelBPE.from_pretrained(model_name="stt_en_conformer_ctc_small")
model.set_trainer(Trainer(devices=1, accelerator="cpu"))
model = model.eval()
return model


@pytest.fixture(scope="module")
def conformer_rnnt_bpe_model():
model = EncDecRNNTBPEModel.from_pretrained(model_name="stt_en_conformer_transducer_small")
model.set_trainer(Trainer(devices=1, accelerator="cpu"))
model = model.eval()
return model


@pytest.mark.with_downloads
@pytest.fixture(scope="module")
# @pytest.fixture
def audio_and_texts(test_data_dir):
# get filenames and reference texts from manifest
filepaths = []
Expand All @@ -72,24 +56,28 @@ def audio_and_texts(test_data_dir):


class TestASRConfidenceBenchmark:
@pytest.mark.pleasefixme
@pytest.mark.integration
@pytest.mark.with_downloads
@pytest.mark.parametrize('model_name', ("ctc", "rnnt"))
@pytest.mark.parametrize('target_level', ("token", "word"))
def test_run_confidence_benchmark(
self, model_name, target_level, audio_and_texts, conformer_ctc_bpe_model, conformer_rnnt_bpe_model
self, model_name, target_level, audio_and_texts, fast_conformer_ctc_model, fast_conformer_transducer_model
):
model = conformer_ctc_bpe_model if model_name == "ctc" else conformer_rnnt_bpe_model
model = fast_conformer_ctc_model if model_name == "ctc" else fast_conformer_transducer_model
assert isinstance(model, ASRModel)
filepaths, reference_texts = audio_and_texts
confidence_cfg = (
ConfidenceConfig(preserve_token_confidence=True)
ConfidenceConfig(preserve_frame_confidence=True, preserve_token_confidence=True)
if target_level == "token"
else ConfidenceConfig(preserve_word_confidence=True)
else ConfidenceConfig(preserve_frame_confidence=True, preserve_word_confidence=True)
)
model.change_decoding_strategy(
RNNTDecodingConfig(fused_batch_size=-1, strategy="greedy_batch", confidence_cfg=confidence_cfg)
RNNTDecodingConfig(
fused_batch_size=-1,
strategy="greedy_batch",
confidence_cfg=confidence_cfg,
greedy=GreedyBatchedRNNTInferConfig(loop_labels=False),
)
if model_name == "rnnt"
else CTCDecodingConfig(confidence_cfg=confidence_cfg)
)
Expand All @@ -104,13 +92,12 @@ def test_run_confidence_benchmark(
atol=TOL,
)

@pytest.mark.pleasefixme
@pytest.mark.integration
@pytest.mark.with_downloads
@pytest.mark.parametrize('model_name', ("ctc", "rnnt"))
def test_deprecated_config_args(self, model_name, conformer_ctc_bpe_model, conformer_rnnt_bpe_model):
def test_deprecated_config_args(self, model_name, fast_conformer_ctc_model, fast_conformer_transducer_model):
assert ConfidenceConfig().method_cfg.alpha == 0.33, "default `alpha` is supposed to be 0.33"
model = conformer_ctc_bpe_model if model_name == "ctc" else conformer_rnnt_bpe_model
model = fast_conformer_ctc_model if model_name == "ctc" else fast_conformer_transducer_model
assert isinstance(model, ASRModel)

conf = OmegaConf.create({"temperature": 0.5})
Expand All @@ -133,3 +120,50 @@ def test_deprecated_config_args(self, model_name, conformer_ctc_bpe_model, confo
else CTCDecodingConfig(greedy=GreedyCTCInferConfig(preserve_frame_confidence=True, **test_args_greedy))
)
assert model.cfg.decoding.greedy.confidence_method_cfg.alpha == 0.5

@pytest.mark.unit
def test_aed_multitask_model_confidence(self, canary_1b_v2, test_data_dir):
"""Test token and word confidence for AED multitask models (Canary)."""
model = canary_1b_v2
assert isinstance(model, EncDecMultiTaskModel)

audio_file = Path(test_data_dir) / "asr" / "train" / "an4" / "wav" / "an46-mmap-b.wav"

# Configure decoding with confidence
decode_cfg = MultiTaskDecodingConfig(
strategy='greedy',
greedy=AEDGreedyInferConfig(preserve_token_confidence=True),
confidence_cfg=ConfidenceConfig(preserve_token_confidence=True, preserve_word_confidence=True),
)
model.change_decoding_strategy(decode_cfg)

hypotheses = model.transcribe(
audio=str(audio_file),
batch_size=1,
return_hypotheses=True,
)

assert len(hypotheses) == 1
hyp = hypotheses[0]

# Verify text is present
assert isinstance(hyp.text, str)
assert len(hyp.text) > 0

# Verify y_sequence is present
assert hyp.y_sequence is not None
assert len(hyp.y_sequence) > 0

# Verify token confidence is present and has correct length
assert hyp.token_confidence is not None
assert len(hyp.token_confidence) == len(hyp.y_sequence)

# Verify word confidence is present
assert hyp.word_confidence is not None
assert len(hyp.word_confidence) > 0

# Verify confidence values are in valid range [0, 1]
for conf in hyp.token_confidence:
assert 0.0 <= conf <= 1.0, f"Token confidence {conf} not in valid range [0, 1]"
for conf in hyp.word_confidence:
assert 0.0 <= conf <= 1.0, f"Word confidence {conf} not in valid range [0, 1]"
1 change: 0 additions & 1 deletion tests/collections/asr/mixins/test_transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def __len__(self):
return len(self.audio_tensors)


@pytest.mark.with_downloads()
@pytest.fixture()
def audio_files(test_data_dir):
"""
Expand Down
Loading