|
1 | 1 | import inspect |
2 | 2 | import json |
3 | 3 | import os |
4 | | -from dataclasses import asdict, dataclass, is_dataclass |
| 4 | +from dataclasses import Field, asdict, dataclass, is_dataclass |
5 | 5 | from pathlib import Path |
6 | | -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union |
| 6 | +from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union |
7 | 7 |
|
8 | 8 | import packaging.version |
9 | 9 |
|
|
24 | 24 | ) |
25 | 25 |
|
26 | 26 |
|
27 | | -if TYPE_CHECKING: |
28 | | - from _typeshed import DataclassInstance |
29 | | - |
30 | 27 | if is_torch_available(): |
31 | 28 | import torch # type: ignore |
32 | 29 |
|
|
38 | 35 |
|
39 | 36 | logger = logging.get_logger(__name__) |
40 | 37 |
|
| 38 | + |
| 39 | +# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349 |
| 40 | +class DataclassInstance(Protocol): |
| 41 | + __dataclass_fields__: ClassVar[Dict[str, Field]] |
| 42 | + |
| 43 | + |
41 | 44 | # Generic variable that is either ModelHubMixin or a subclass thereof |
42 | 45 | T = TypeVar("T", bound="ModelHubMixin") |
43 | 46 | # Generic variable to represent an args type |
@@ -175,7 +178,7 @@ class ModelHubMixin: |
175 | 178 | ``` |
176 | 179 | """ |
177 | 180 |
|
178 | | - _hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None |
| 181 | + _hub_mixin_config: Optional[Union[dict, DataclassInstance]] = None |
179 | 182 | # ^ optional config attribute automatically set in `from_pretrained` |
180 | 183 | _hub_mixin_info: MixinInfo |
181 | 184 | # ^ information about the library integrating ModelHubMixin (used to generate model card) |
@@ -366,7 +369,7 @@ def save_pretrained( |
366 | 369 | self, |
367 | 370 | save_directory: Union[str, Path], |
368 | 371 | *, |
369 | | - config: Optional[Union[dict, "DataclassInstance"]] = None, |
| 372 | + config: Optional[Union[dict, DataclassInstance]] = None, |
370 | 373 | repo_id: Optional[str] = None, |
371 | 374 | push_to_hub: bool = False, |
372 | 375 | model_card_kwargs: Optional[Dict[str, Any]] = None, |
@@ -618,7 +621,7 @@ def push_to_hub( |
618 | 621 | self, |
619 | 622 | repo_id: str, |
620 | 623 | *, |
621 | | - config: Optional[Union[dict, "DataclassInstance"]] = None, |
| 624 | + config: Optional[Union[dict, DataclassInstance]] = None, |
622 | 625 | commit_message: str = "Push model using huggingface_hub.", |
623 | 626 | private: Optional[bool] = None, |
624 | 627 | token: Optional[str] = None, |
@@ -825,7 +828,7 @@ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, stric |
825 | 828 | return model |
826 | 829 |
|
827 | 830 |
|
828 | | -def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance": |
| 831 | +def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance: |
829 | 832 | """Load a dataclass instance from a dictionary. |
830 | 833 |
|
831 | 834 | Fields not expected by the dataclass are ignored. |
|
0 commit comments