|  | 
|  | 1 | +import importlib | 
|  | 2 | +import os | 
|  | 3 | +from typing import Optional, Union | 
|  | 4 | + | 
|  | 5 | +from huggingface_hub.utils import validate_hf_hub_args | 
|  | 6 | + | 
|  | 7 | +from ..configuration_utils import ConfigMixin | 
|  | 8 | +from ..utils import CONFIG_NAME | 
|  | 9 | +from .modeling_utils import ModelMixin | 
|  | 10 | + | 
|  | 11 | + | 
|  | 12 | +class AutoModel(ConfigMixin): | 
|  | 13 | +    config_name = CONFIG_NAME | 
|  | 14 | + | 
|  | 15 | +    """TODO""" | 
|  | 16 | + | 
|  | 17 | +    def __init__(self, *args, **kwargs): | 
|  | 18 | +        raise EnvironmentError( | 
|  | 19 | +            f"{self.__class__.__name__} is designed to be instantiated " | 
|  | 20 | +            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " | 
|  | 21 | +            "from_config(config) methods." | 
|  | 22 | +        ) | 
|  | 23 | + | 
|  | 24 | +    @classmethod | 
|  | 25 | +    @validate_hf_hub_args | 
|  | 26 | +    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): | 
|  | 27 | +        """TODO""" | 
|  | 28 | +        cache_dir = kwargs.pop("cache_dir", None) | 
|  | 29 | +        force_download = kwargs.pop("force_download", False) | 
|  | 30 | +        proxies = kwargs.pop("proxies", None) | 
|  | 31 | +        token = kwargs.pop("token", None) | 
|  | 32 | +        local_files_only = kwargs.pop("local_files_only", False) | 
|  | 33 | +        revision = kwargs.pop("revision", None) | 
|  | 34 | +        subfolder = kwargs.pop("subfolder", None) | 
|  | 35 | + | 
|  | 36 | +        load_config_kwargs = { | 
|  | 37 | +            "cache_dir": cache_dir, | 
|  | 38 | +            "force_download": force_download, | 
|  | 39 | +            "proxies": proxies, | 
|  | 40 | +            "token": token, | 
|  | 41 | +            "local_files_only": local_files_only, | 
|  | 42 | +            "revision": revision, | 
|  | 43 | +            "subfolder": subfolder, | 
|  | 44 | +        } | 
|  | 45 | +        config = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs) | 
|  | 46 | +        class_name = config["_class_name"] | 
|  | 47 | +        diffusers_module = importlib.import_module(__name__.split(".")[0]) | 
|  | 48 | + | 
|  | 49 | +        try: | 
|  | 50 | +            model_cls: ModelMixin = getattr(diffusers_module, class_name) | 
|  | 51 | +        except Exception: | 
|  | 52 | +            raise ValueError(f"Could not import the `{class_name}` class from diffusers.") | 
|  | 53 | + | 
|  | 54 | +        kwargs = {**load_config_kwargs, **kwargs} | 
|  | 55 | +        return model_cls.from_pretrained(pretrained_model_name_or_path, **kwargs) | 
0 commit comments