2020
2121import torch
2222import transformers
23+ from omegaconf import DictConfig , OmegaConf
2324from torch import Tensor , nn
2425from torch .nn import functional as F
2526from transformers import AutoConfig , AutoModel , AutoModelForTextEncoding , AutoTokenizer , Cache
2930from nemo .collections .speechlm2 .parts .precision import fp32_precision
3031from nemo .collections .speechlm2 .parts .pretrained import set_model_dict_for_partial_init
3132from nemo .utils import logging
32- from omegaconf import DictConfig , OmegaConf
3333
3434# ==============================================================================
3535# MLP module and Norm
@@ -894,7 +894,9 @@ def __init__(
894894
895895 # 2. Initialize the backbone model
896896 if backbone_type :
897- config = AutoConfig .for_model (backbone_type , ** (OmegaConf .to_container (backbone_config , resolve = True ) if backbone_config else {}))
897+ config = AutoConfig .for_model (
898+ backbone_type , ** (OmegaConf .to_container (backbone_config , resolve = True ) if backbone_config else {})
899+ )
898900 self .backbone = AutoModelForTextEncoding .from_config (config )
899901 else :
900902 assert backbone_model_class and backbone_config_class
@@ -1044,12 +1046,12 @@ class RVQEARTTSModel(PreTrainedModel):
10441046 Args:
10451047 config (DictConfig | dict[str, Any]): The configuration object for the model.
10461048 """
1049+
10471050 rvq_embs : Tensor
10481051
10491052 def __init__ (self , config : DictConfig | dict [str , Any ]):
10501053 super ().__init__ (config )
10511054
1052-
10531055 # Backbone module
10541056 if self .config .get ("pretrained_text_name" , None ):
10551057 # Load pretrained backbone from huggingface
@@ -1059,15 +1061,26 @@ def __init__(self, config: DictConfig | dict[str, Any]):
10591061 self .backbone = llm .model # fetch PretrainedBaseModel from model "ForCausalLM"
10601062 else :
10611063 if self .config .get ("backbone_type" , None ) is None :
1062- assert self .config .get ("backbone_model_class" , None ) is not None and self .config .get ("backbone_config_class" , None ) is not None
1064+ assert (
1065+ self .config .get ("backbone_model_class" , None ) is not None
1066+ and self .config .get ("backbone_config_class" , None ) is not None
1067+ )
10631068 backbone_config = getattr (transformers , self .config .backbone_config_class )(
1064- ** (OmegaConf .to_container (self .config .backbone_config , resolve = True ) if self .config .backbone_config else {}),
1069+ ** (
1070+ OmegaConf .to_container (self .config .backbone_config , resolve = True )
1071+ if self .config .backbone_config
1072+ else {}
1073+ ),
10651074 )
10661075 self .backbone = getattr (transformers , self .config .backbone_model_class )(backbone_config )
10671076 else :
10681077 backbone_config = AutoConfig .for_model (
10691078 self .config .backbone_type ,
1070- ** (OmegaConf .to_container (self .config .backbone_config , resolve = True ) if self .config .backbone_config else {}),
1079+ ** (
1080+ OmegaConf .to_container (self .config .backbone_config , resolve = True )
1081+ if self .config .backbone_config
1082+ else {}
1083+ ),
10711084 )
10721085 self .backbone = AutoModel .from_config (backbone_config )
10731086
0 commit comments