Commit e765ed0

Apply isort and black reformatting
Signed-off-by: Edresson <Edresson@users.noreply.github.com>
1 parent c05392e commit e765ed0
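The diff below is purely mechanical: isort regroups the omegaconf imports with the other third-party imports, and black re-wraps statements that exceed the configured line length. A minimal sketch of re-applying the same formatters programmatically to one of the touched files (a hypothetical snippet; the 119-character line length is an assumption, not read from this repository's configuration):

# Hypothetical sketch: reapply isort + black to one file from this commit.
# Assumes both packages are installed; line_length=119 is an assumption.
import black
import isort

path = "nemo/collections/speechlm2/modules/ear_tts_commons.py"
with open(path) as f:
    source = f.read()

source = isort.code(source)  # regroup and sort the import block
source = black.format_str(source, mode=black.Mode(line_length=119))  # re-wrap long lines

with open(path, "w") as f:
    f.write(source)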

File tree

4 files changed: +31 -9 lines changed


nemo/collections/speechlm2/models/duplex_ear_tts.py

Lines changed: 7 additions & 1 deletion
@@ -321,7 +321,13 @@ def setup_rvq_audio_codec(model):
     with fp32_precision():
         if model.cfg.get("pretrained_ae_dir", None):
             model.audio_codec = (
-                RVQVAEModel.from_pretrained(model.cfg.pretrained_ae_dir, cfg=DictConfig(model.cfg.codec_config) if model.cfg.get("codec_config", None) else None, strict=False).eval().to(model.device)
+                RVQVAEModel.from_pretrained(
+                    model.cfg.pretrained_ae_dir,
+                    cfg=DictConfig(model.cfg.codec_config) if model.cfg.get("codec_config", None) else None,
+                    strict=False,
+                )
+                .eval()
+                .to(model.device)
             )
         else:
             # init codec from config

nemo/collections/speechlm2/modules/ear_tts_commons.py

Lines changed: 2 additions & 1 deletion
@@ -24,11 +24,11 @@
 from collections.abc import Mapping
 from typing import Any
 
+from omegaconf import DictConfig
 from safetensors import safe_open
 from torch import nn
 
 from nemo.utils import logging
-from omegaconf import DictConfig
 
 # ==============================================================================
 # Contants
@@ -152,6 +152,7 @@ def get_config_from_dir(workdir_path: str) -> DictConfig:
 # Base Model Classes
 # ==============================================================================
 
+
 class PreTrainedModel(nn.Module):
     config_class = DictConfig
 
nemo/collections/speechlm2/modules/rvq_ear_tts_model.py

Lines changed: 19 additions & 6 deletions
@@ -20,6 +20,7 @@
 
 import torch
 import transformers
+from omegaconf import DictConfig, OmegaConf
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding, AutoTokenizer, Cache
@@ -29,7 +30,6 @@
 from nemo.collections.speechlm2.parts.precision import fp32_precision
 from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init
 from nemo.utils import logging
-from omegaconf import DictConfig, OmegaConf
 
 # ==============================================================================
 # MLP module and Norm
@@ -894,7 +894,9 @@ def __init__(
 
         # 2. Initialize the backbone model
        if backbone_type:
-            config = AutoConfig.for_model(backbone_type, **(OmegaConf.to_container(backbone_config, resolve=True) if backbone_config else {}))
+            config = AutoConfig.for_model(
+                backbone_type, **(OmegaConf.to_container(backbone_config, resolve=True) if backbone_config else {})
+            )
             self.backbone = AutoModelForTextEncoding.from_config(config)
         else:
             assert backbone_model_class and backbone_config_class
@@ -1044,12 +1046,12 @@ class RVQEARTTSModel(PreTrainedModel):
     Args:
         config (DictConfig | dict[str, Any]): The configuration object for the model.
     """
+
     rvq_embs: Tensor
 
     def __init__(self, config: DictConfig | dict[str, Any]):
         super().__init__(config)
 
-
         # Backbone module
         if self.config.get("pretrained_text_name", None):
             # Load pretrained backbone from huggingface
@@ -1059,15 +1061,26 @@ def __init__(self, config: DictConfig | dict[str, Any]):
             self.backbone = llm.model  # fetch PretrainedBaseModel from model "ForCausalLM"
         else:
             if self.config.get("backbone_type", None) is None:
-                assert self.config.get("backbone_model_class", None) is not None and self.config.get("backbone_config_class", None) is not None
+                assert (
+                    self.config.get("backbone_model_class", None) is not None
+                    and self.config.get("backbone_config_class", None) is not None
+                )
                 backbone_config = getattr(transformers, self.config.backbone_config_class)(
-                    **(OmegaConf.to_container(self.config.backbone_config, resolve=True) if self.config.backbone_config else {}),
+                    **(
+                        OmegaConf.to_container(self.config.backbone_config, resolve=True)
+                        if self.config.backbone_config
+                        else {}
+                    ),
                 )
                 self.backbone = getattr(transformers, self.config.backbone_model_class)(backbone_config)
             else:
                 backbone_config = AutoConfig.for_model(
                     self.config.backbone_type,
-                    **(OmegaConf.to_container(self.config.backbone_config, resolve=True) if self.config.backbone_config else {}),
+                    **(
+                        OmegaConf.to_container(self.config.backbone_config, resolve=True)
+                        if self.config.backbone_config
+                        else {}
+                    ),
                 )
                 self.backbone = AutoModel.from_config(backbone_config)
 

nemo/collections/speechlm2/modules/rvq_ear_tts_vae.py

Lines changed: 3 additions & 1 deletion
@@ -20,13 +20,14 @@
 
 # Third-party
 import torch
+from omegaconf import DictConfig
 from torch import Tensor, nn
 from torch.nn import functional as F
 from torchaudio import functional as ta_F
 
 # Project
 from nemo.collections.speechlm2.modules.ear_tts_commons import PreTrainedModel
-from omegaconf import DictConfig
+
 
 @contextmanager
 def disable_tf32():
@@ -37,6 +38,7 @@ def disable_tf32():
     finally:
         torch.backends.cudnn.allow_tf32 = prev
 
+
 # ==============================================================================
 # Utility Functions
 # ==============================================================================
