11import os
22from collections import OrderedDict
33from collections .abc import Iterator
4- from typing import Any
4+ from typing import Any , TypeVar , Union
55from typing_extensions import TypeAlias
66
77from transformers .configuration_utils import PretrainedConfig
88from transformers .tokenization_utils_fast import PreTrainedTokenizerFast
99
10- _LazyAutoMappingValue : TypeAlias = tuple [
11- # Tokenizers will depend on packages installed, too much variance and there are no common base or Protocol
12- type [Any | None ],
13- type [PreTrainedTokenizerFast | None ],
14- ]
10+ _T = TypeVar ("_T" )
11+ # Tokenizers will depend on packages installed, too much variance and there are no common base or Protocol
12+ _LazyAutoMappingValue : TypeAlias = tuple [type [Any ] | None , type [Any ] | None ]
1513
1614CLASS_DOCSTRING : str
1715FROM_CONFIG_DOCSTRING : str
@@ -26,7 +24,7 @@ class _BaseAutoModelClass:
2624 @classmethod
2725 def from_pretrained (cls , pretrained_model_name_or_path : str | os .PathLike [str ], * model_args , ** kwargs ): ...
2826 @classmethod
29- def register (cls , config_class , model_class ) -> None : ...
27+ def register (cls , config_class , model_class , exist_ok = False ) -> None : ...
3028
3129def insert_head_doc (docstring , head_doc : str = "" ): ...
3230def auto_class_update (cls , checkpoint_for_example : str = "bert-base-cased" , head_doc : str = "" ): ...
@@ -38,10 +36,10 @@ class _LazyAutoMapping(OrderedDict[type[PretrainedConfig], _LazyAutoMappingValue
3836 def __len__ (self ) -> int : ...
3937 def __getitem__ (self , key : type [PretrainedConfig ]) -> _LazyAutoMappingValue : ...
4038 def keys (self ) -> list [type [PretrainedConfig ]]: ...
41- def get (self , key : type [PretrainedConfig ], default : _LazyAutoMappingValue ) -> _LazyAutoMappingValue : ...
39+ def get (self , key : type [PretrainedConfig ], default : _T ) -> _LazyAutoMappingValue | _T : ...
4240 def __bool__ (self ) -> bool : ...
4341 def values (self ) -> list [_LazyAutoMappingValue ]: ...
4442 def items (self ) -> list [tuple [type [PretrainedConfig ], _LazyAutoMappingValue ]]: ...
4543 def __iter__ (self ) -> Iterator [type [PretrainedConfig ]]: ...
46- def __contains__ (self , item : object ) -> bool : ...
47- def register (self , key : type [PretrainedConfig ], value : _LazyAutoMappingValue ) -> None : ...
44+ def __contains__ (self , item : type ) -> bool : ...
45+ def register (self , key : type [PretrainedConfig ], value : _LazyAutoMappingValue , exist_ok = False ) -> None : ...
0 commit comments