1414# limitations under the License.
1515"""BARK model configuration"""
1616
17- import os
18- from typing import Dict , Optional , Union
17+ from typing import Dict
1918
2019from ...configuration_utils import PretrainedConfig
2120from ...utils import add_start_docstrings , logging
22- from ..auto import CONFIG_MAPPING
21+ from ..auto import CONFIG_MAPPING , AutoConfig
2322
2423
2524logger = logging .get_logger (__name__ )
6463
6564
6665class BarkSubModelConfig (PretrainedConfig ):
67- model_type = "bark_module"
6866 keys_to_ignore_at_inference = ["past_key_values" ]
6967
7068 attribute_map = {
@@ -101,38 +99,6 @@ def __init__(
10199
102100 super ().__init__ (** kwargs )
103101
104- @classmethod
105- def from_pretrained (
106- cls ,
107- pretrained_model_name_or_path : Union [str , os .PathLike ],
108- cache_dir : Optional [Union [str , os .PathLike ]] = None ,
109- force_download : bool = False ,
110- local_files_only : bool = False ,
111- token : Optional [Union [str , bool ]] = None ,
112- revision : str = "main" ,
113- ** kwargs ,
114- ) -> "PretrainedConfig" :
115- kwargs ["cache_dir" ] = cache_dir
116- kwargs ["force_download" ] = force_download
117- kwargs ["local_files_only" ] = local_files_only
118- kwargs ["revision" ] = revision
119-
120- cls ._set_token_in_kwargs (kwargs , token )
121-
122- config_dict , kwargs = cls .get_config_dict (pretrained_model_name_or_path , ** kwargs )
123-
124- # get the config dict if we are loading from Bark
125- if config_dict .get ("model_type" ) == "bark" :
126- config_dict = config_dict [f"{ cls .model_type } _config" ]
127-
128- if "model_type" in config_dict and hasattr (cls , "model_type" ) and config_dict ["model_type" ] != cls .model_type :
129- logger .warning (
130- f"You are using a model of type { config_dict ['model_type' ]} to instantiate a model of type "
131- f"{ cls .model_type } . This is not supported for all configurations of models and can yield errors."
132- )
133-
134- return cls .from_dict (config_dict , ** kwargs )
135-
136102
137103@add_start_docstrings (
138104 BARK_SUBMODELCONFIG_START_DOCSTRING .format (config = "BarkSemanticConfig" , model = "BarkSemanticModel" ),
@@ -154,6 +120,7 @@ def from_pretrained(
154120)
155121class BarkSemanticConfig (BarkSubModelConfig ):
156122 model_type = "semantic"
123+ base_config_key = "semantic_config"
157124
158125
159126@add_start_docstrings (
@@ -176,6 +143,7 @@ class BarkSemanticConfig(BarkSubModelConfig):
176143)
177144class BarkCoarseConfig (BarkSubModelConfig ):
178145 model_type = "coarse_acoustics"
146+ base_config_key = "coarse_acoustics_config"
179147
180148
181149@add_start_docstrings (
@@ -203,6 +171,7 @@ class BarkCoarseConfig(BarkSubModelConfig):
203171)
204172class BarkFineConfig (BarkSubModelConfig ):
205173 model_type = "fine_acoustics"
174+ base_config_key = "fine_acoustics_config"
206175
207176 def __init__ (self , tie_word_embeddings = True , n_codes_total = 8 , n_codes_given = 1 , ** kwargs ):
208177 self .n_codes_total = n_codes_total
@@ -265,6 +234,12 @@ class BarkConfig(PretrainedConfig):
265234 """
266235
267236 model_type = "bark"
237+ sub_configs = {
238+ "semantic_config" : BarkSemanticConfig ,
239+ "coarse_acoustics_config" : BarkCoarseConfig ,
240+ "fine_acoustics_config" : BarkFineConfig ,
241+ "codec_config" : AutoConfig ,
242+ }
268243
269244 def __init__ (
270245 self ,
0 commit comments