
Commit 893ad04

Load sub-configs from composite configs (#34410)
* save/load sub-configs
* nit: forgot these
* fix copies
* move test to common
* use dict for sub-configs
* add load-save-load test
* clean up modeling check
* oops, these are correct keys
* fix some tests, missed some composite configs
* this model was missed
1 parent 5e1fd4e · commit 893ad04

File tree: 78 files changed (+464 / -1052 lines)


src/transformers/configuration_utils.py

Lines changed: 17 additions & 4 deletions
@@ -190,6 +190,8 @@ class PretrainedConfig(PushToHubMixin):
     """
 
     model_type: str = ""
+    base_config_key: str = ""
+    sub_configs: Dict[str, "PretrainedConfig"] = {}
     is_composition: bool = False
     attribute_map: Dict[str, str] = {}
     _auto_class: Optional[str] = None
@@ -543,11 +545,22 @@ def from_pretrained(
         cls._set_token_in_kwargs(kwargs, token)
 
         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+        if cls.base_config_key and cls.base_config_key in config_dict:
+            config_dict = config_dict[cls.base_config_key]
+
         if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
+            # sometimes the config has no `base_config_key` if the config is used in several composite models
+            # e.g. LlamaConfig. In that case we try to see if there is match in `model_type` before raising a warning
+            for k, v in config_dict.items():
+                if isinstance(v, dict) and v.get("model_type") == cls.model_type:
+                    config_dict = v
+
+            # raise warning only if we still can't see a match in `model_type`
+            if config_dict["model_type"] != cls.model_type:
+                logger.warning(
+                    f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                    f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+                )
 
         return cls.from_dict(config_dict, **kwargs)
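The mechanics above are generic: a sub-config class opts in by setting `base_config_key`, and a composite class lists its children in `sub_configs`. A minimal sketch of the pattern (the `My*` classes are hypothetical, for illustration only):

```python
from transformers import PretrainedConfig


class MyTextConfig(PretrainedConfig):  # hypothetical sub-config
    model_type = "my_text_model"
    # Key under which this config is nested inside a composite config dict.
    base_config_key = "text_config"


class MyCompositeConfig(PretrainedConfig):  # hypothetical composite config
    model_type = "my_model"
    # Declares the nested configs so generic code can discover them.
    sub_configs = {"text_config": MyTextConfig}


# MyTextConfig.from_pretrained(<composite checkpoint>) now picks
# config_dict["text_config"] out of the composite dict automatically,
# instead of each sub-config class overriding from_pretrained by hand.
```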

src/transformers/modeling_utils.py

Lines changed: 8 additions & 9 deletions
@@ -1608,15 +1608,14 @@ def _autoset_attn_implementation(
         # Below we check if a config is composite and manually prepare a dict of attn impl if not already passed as a dict.
         # Later each sub-module will dispatch with its own attn impl, by calling `XXXModel._from_config(config.text_config)`
         # If any of sub-modules doesn't support requested attn, an error will be raised. See https://github.com/huggingface/transformers/pull/32238
-        for key in config:
-            if isinstance(getattr(config, key), PretrainedConfig):
-                sub_config = getattr(config, key)
-                curr_attn_implementation = (
-                    requested_attn_implementation
-                    if not isinstance(requested_attn_implementation, dict)
-                    else requested_attn_implementation.get(key, None)
-                )
-                sub_config._attn_implementation_internal = curr_attn_implementation
+        for key in config.sub_configs.keys():
+            sub_config = getattr(config, key)
+            curr_attn_implementation = (
+                requested_attn_implementation
+                if not isinstance(requested_attn_implementation, dict)
+                else requested_attn_implementation.get(key, None)
+            )
+            sub_config._attn_implementation_internal = curr_attn_implementation
 
         if use_flash_attention_2:
             logger.warning_once(
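Because the loop now walks the declared `sub_configs` keys instead of isinstance-scanning every attribute, a per-module attention implementation can be requested as a dict. A sketch, assuming a LLaVA-style composite checkpoint (checkpoint name illustrative; the dict dispatch itself comes from #32238):

```python
from transformers import LlavaForConditionalGeneration

# Each key must match an entry in config.sub_configs; each sub-module then
# dispatches with its own attention implementation.
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",  # illustrative checkpoint
    attn_implementation={"text_config": "sdpa", "vision_config": "eager"},
)
```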

src/transformers/models/align/configuration_align.py

Lines changed: 4 additions & 38 deletions
@@ -14,8 +14,7 @@
 # limitations under the License.
 """ALIGN model configuration"""
 
-import os
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, List
 
 
 if TYPE_CHECKING:
@@ -95,6 +94,7 @@ class AlignTextConfig(PretrainedConfig):
     ```"""
 
     model_type = "align_text_model"
+    base_config_key = "text_config"
 
     def __init__(
         self,
@@ -133,24 +133,6 @@ def __init__(
         self.use_cache = use_cache
         self.pad_token_id = pad_token_id
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
-        cls._set_token_in_kwargs(kwargs)
-
-        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
-
-        # get the text config dict if we are loading from AlignConfig
-        if config_dict.get("model_type") == "align":
-            config_dict = config_dict["text_config"]
-
-        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
-
-        return cls.from_dict(config_dict, **kwargs)
-
 
 class AlignVisionConfig(PretrainedConfig):
     r"""
@@ -223,6 +205,7 @@ class AlignVisionConfig(PretrainedConfig):
     ```"""
 
     model_type = "align_vision_model"
+    base_config_key = "vision_config"
 
     def __init__(
         self,
@@ -272,24 +255,6 @@ def __init__(
         self.drop_connect_rate = drop_connect_rate
         self.num_hidden_layers = sum(num_block_repeats) * 4
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
-        cls._set_token_in_kwargs(kwargs)
-
-        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
-
-        # get the vision config dict if we are loading from AlignConfig
-        if config_dict.get("model_type") == "align":
-            config_dict = config_dict["vision_config"]
-
-        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
-
-        return cls.from_dict(config_dict, **kwargs)
-
 
 class AlignConfig(PretrainedConfig):
     r"""
@@ -340,6 +305,7 @@ class AlignConfig(PretrainedConfig):
     ```"""
 
    model_type = "align"
+    sub_configs = {"text_config": AlignTextConfig, "vision_config": AlignVisionConfig}
 
     def __init__(
         self,
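With `base_config_key` set, loading a sub-config straight from a composite ALIGN checkpoint goes through the shared `PretrainedConfig.from_pretrained` path, so the deleted per-class overrides above are no longer needed. A quick sketch (checkpoint name illustrative):

```python
from transformers import AlignTextConfig

# from_pretrained() sees base_config_key = "text_config" and extracts the
# nested dict from the composite config on its own.
text_config = AlignTextConfig.from_pretrained("kakaobrain/align-base")
print(text_config.model_type)  # align_text_model
```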

src/transformers/models/altclip/configuration_altclip.py

Lines changed: 2 additions & 21 deletions
@@ -14,9 +14,6 @@
 # limitations under the License.
 """AltCLIP model configuration"""
 
-import os
-from typing import Union
-
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
 
@@ -199,6 +196,7 @@ class AltCLIPVisionConfig(PretrainedConfig):
     ```"""
 
     model_type = "altclip_vision_model"
+    base_config_key = "vision_config"
 
     def __init__(
         self,
@@ -233,24 +231,6 @@ def __init__(
         self.layer_norm_eps = layer_norm_eps
         self.hidden_act = hidden_act
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
-        cls._set_token_in_kwargs(kwargs)
-
-        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
-
-        # get the vision config dict if we are loading from AltCLIPConfig
-        if config_dict.get("model_type") == "altclip":
-            config_dict = config_dict["vision_config"]
-
-        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
-
-        return cls.from_dict(config_dict, **kwargs)
-
 
 class AltCLIPConfig(PretrainedConfig):
     r"""
@@ -298,6 +278,7 @@ class AltCLIPConfig(PretrainedConfig):
     ```"""
 
     model_type = "altclip"
+    sub_configs = {"text_config": AltCLIPTextConfig, "vision_config": AltCLIPVisionConfig}
 
     def __init__(
         self, text_config=None, vision_config=None, projection_dim=768, logit_scale_init_value=2.6592, **kwargs
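The class-level `sub_configs` mapping also makes composite configs introspectable without `isinstance` scans over arbitrary attributes, which is exactly what the `modeling_utils.py` change above relies on. A small sketch:

```python
from transformers import AltCLIPConfig

config = AltCLIPConfig()  # default sub-configs are created when none are passed
for key, sub_cls in AltCLIPConfig.sub_configs.items():
    # getattr(config, key) is an instance of sub_cls
    print(key, sub_cls.__name__, getattr(config, key).model_type)
```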

src/transformers/models/bark/configuration_bark.py

Lines changed: 11 additions & 36 deletions
@@ -14,12 +14,11 @@
 # limitations under the License.
 """BARK model configuration"""
 
-import os
-from typing import Dict, Optional, Union
+from typing import Dict
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import add_start_docstrings, logging
-from ..auto import CONFIG_MAPPING
+from ..auto import CONFIG_MAPPING, AutoConfig
 
 
 logger = logging.get_logger(__name__)
@@ -64,7 +63,6 @@
 
 
 class BarkSubModelConfig(PretrainedConfig):
-    model_type = "bark_module"
     keys_to_ignore_at_inference = ["past_key_values"]
 
     attribute_map = {
@@ -101,38 +99,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-    @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: Union[str, os.PathLike],
-        cache_dir: Optional[Union[str, os.PathLike]] = None,
-        force_download: bool = False,
-        local_files_only: bool = False,
-        token: Optional[Union[str, bool]] = None,
-        revision: str = "main",
-        **kwargs,
-    ) -> "PretrainedConfig":
-        kwargs["cache_dir"] = cache_dir
-        kwargs["force_download"] = force_download
-        kwargs["local_files_only"] = local_files_only
-        kwargs["revision"] = revision
-
-        cls._set_token_in_kwargs(kwargs, token)
-
-        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
-
-        # get the config dict if we are loading from Bark
-        if config_dict.get("model_type") == "bark":
-            config_dict = config_dict[f"{cls.model_type}_config"]
-
-        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
-
-        return cls.from_dict(config_dict, **kwargs)
-
 
 @add_start_docstrings(
     BARK_SUBMODELCONFIG_START_DOCSTRING.format(config="BarkSemanticConfig", model="BarkSemanticModel"),
@@ -154,6 +120,7 @@ def from_pretrained(
 )
 class BarkSemanticConfig(BarkSubModelConfig):
     model_type = "semantic"
+    base_config_key = "semantic_config"
 
 
 @add_start_docstrings(
@@ -176,6 +143,7 @@ class BarkSemanticConfig(BarkSubModelConfig):
 )
 class BarkCoarseConfig(BarkSubModelConfig):
     model_type = "coarse_acoustics"
+    base_config_key = "coarse_acoustics_config"
 
 
 @add_start_docstrings(
@@ -203,6 +171,7 @@ class BarkCoarseConfig(BarkSubModelConfig):
 )
 class BarkFineConfig(BarkSubModelConfig):
     model_type = "fine_acoustics"
+    base_config_key = "fine_acoustics_config"
 
     def __init__(self, tie_word_embeddings=True, n_codes_total=8, n_codes_given=1, **kwargs):
         self.n_codes_total = n_codes_total
@@ -265,6 +234,12 @@ class BarkConfig(PretrainedConfig):
     """
 
     model_type = "bark"
+    sub_configs = {
+        "semantic_config": BarkSemanticConfig,
+        "coarse_acoustics_config": BarkCoarseConfig,
+        "fine_acoustics_config": BarkFineConfig,
+        "codec_config": AutoConfig,
+    }
 
     def __init__(
         self,
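Bark declares four sub-configs, with `codec_config` mapped to `AutoConfig` because the codec (an EnCodec model) lives outside the Bark family; hence the new `AutoConfig` import. Loading one sub-config from the composite checkpoint is now a single call (checkpoint name illustrative):

```python
from transformers import BarkSemanticConfig

# base_config_key = "semantic_config" selects the nested dict from the
# composite Bark config; the long hand-written override removed above
# used to do the same thing by hand.
semantic_config = BarkSemanticConfig.from_pretrained("suno/bark")
```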

src/transformers/models/blip/configuration_blip.py

Lines changed: 3 additions & 39 deletions
@@ -14,9 +14,6 @@
 # limitations under the License.
 """Blip model configuration"""
 
-import os
-from typing import Union
-
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
 
@@ -96,6 +93,7 @@ class BlipTextConfig(PretrainedConfig):
     ```"""
 
     model_type = "blip_text_model"
+    base_config_key = "text_config"
 
     def __init__(
         self,
@@ -146,24 +144,6 @@ def __init__(
         self.use_cache = use_cache
         self.label_smoothing = label_smoothing
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
-        cls._set_token_in_kwargs(kwargs)
-
-        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
-
-        # get the text config dict if we are loading from BlipConfig
-        if config_dict.get("model_type") == "blip":
-            config_dict = config_dict["text_config"]
-
-        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
-
-        return cls.from_dict(config_dict, **kwargs)
-
 
 class BlipVisionConfig(PretrainedConfig):
     r"""
@@ -215,6 +195,7 @@ class BlipVisionConfig(PretrainedConfig):
     ```"""
 
     model_type = "blip_vision_model"
+    base_config_key = "vision_config"
 
     def __init__(
         self,
@@ -245,24 +226,6 @@ def __init__(
         self.layer_norm_eps = layer_norm_eps
         self.hidden_act = hidden_act
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
-        cls._set_token_in_kwargs(kwargs)
-
-        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
-
-        # get the vision config dict if we are loading from BlipConfig
-        if config_dict.get("model_type") == "blip":
-            config_dict = config_dict["vision_config"]
-
-        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
-            logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
-
-        return cls.from_dict(config_dict, **kwargs)
-
 
 class BlipConfig(PretrainedConfig):
     r"""
@@ -316,6 +279,7 @@ class BlipConfig(PretrainedConfig):
     ```"""
 
     model_type = "blip"
+    sub_configs = {"text_config": BlipTextConfig, "vision_config": BlipVisionConfig}
 
     def __init__(
         self,
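The commit message mentions a load-save-load test moved to common; roughly, it checks that a composite config round-trips through save/load and that sub-configs can be re-read from the same saved file. A sketch of that shape, assuming default BLIP configs:

```python
import tempfile

from transformers import BlipConfig

config = BlipConfig()
with tempfile.TemporaryDirectory() as tmp:
    config.save_pretrained(tmp)
    reloaded = BlipConfig.from_pretrained(tmp)
    # Sub-config classes load from the same composite file via base_config_key.
    text_config = BlipConfig.sub_configs["text_config"].from_pretrained(tmp)
assert reloaded.model_type == "blip"
assert text_config.model_type == "blip_text_model"
```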
