|
15 | 15 | Type, |
16 | 16 | TypeVar, |
17 | 17 | Union, |
18 | | - get_args, |
19 | 18 | ) |
20 | 19 |
|
21 | 20 | from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE |
@@ -326,12 +325,11 @@ def __new__(cls, *args, **kwargs) -> "ModelHubMixin": |
326 | 325 | if instance._is_jsonable(value) # Only if jsonable or we have a custom encoder |
327 | 326 | }, |
328 | 327 | } |
329 | | - init_config.pop("config", {}) |
| 328 | + passed_config = init_config.pop("config", {}) |
330 | 329 |
|
331 | 330 | # Populate `init_config` with provided config |
332 | | - provided_config = passed_values.get("config") |
333 | | - if isinstance(provided_config, dict): |
334 | | - init_config.update(provided_config) |
| 331 | + if isinstance(passed_config, dict): |
| 332 | + init_config.update(passed_config) |
335 | 333 |
|
336 | 334 | # Set `config` attribute and return |
337 | 335 | if init_config != {}: |
@@ -362,9 +360,14 @@ def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T |
362 | 360 | if value is None: |
363 | 361 | return None |
364 | 362 | expected_type = unwrap_simple_optional_type(expected_type) |
| 363 | + # Dataclass => handle it |
| 364 | + if is_dataclass(expected_type): |
| 365 | + return _load_dataclass(expected_type, value) # type: ignore[return-value] |
| 366 | + # Otherwise => check custom decoders |
365 | 367 | for type_, (_, decoder) in cls._hub_mixin_coders.items(): |
366 | 368 | if inspect.isclass(expected_type) and issubclass(expected_type, type_): |
367 | 369 | return decoder(value) |
| 370 | + # Otherwise => don't decode |
368 | 371 | return value |
369 | 372 |
|
370 | 373 | def save_pretrained( |
@@ -531,18 +534,9 @@ def from_pretrained( |
531 | 534 |
|
532 | 535 | # Check if `config` argument was passed at init |
533 | 536 | if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs: |
534 | | - # Check if `config` argument is a dataclass |
| 537 | + # Decode `config` argument if it was passed |
535 | 538 | config_annotation = cls._hub_mixin_init_parameters["config"].annotation |
536 | | - if config_annotation is inspect.Parameter.empty: |
537 | | - pass # no annotation |
538 | | - elif is_dataclass(config_annotation): |
539 | | - config = _load_dataclass(config_annotation, config) |
540 | | - else: |
541 | | - # if Optional/Union annotation => check if a dataclass is in the Union |
542 | | - for _sub_annotation in get_args(config_annotation): |
543 | | - if is_dataclass(_sub_annotation): |
544 | | - config = _load_dataclass(_sub_annotation, config) |
545 | | - break |
| 539 | + config = cls._decode_arg(config_annotation, config) |
546 | 540 |
|
547 | 541 | # Forward config to model initialization |
548 | 542 | model_kwargs["config"] = config |
|
0 commit comments