|
42 | 42 | is_torch_version, |
43 | 43 | logging, |
44 | 44 | ) |
45 | | -from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card |
| 45 | +from ..utils.hub_utils import ( |
| 46 | + PushToHubMixin, |
| 47 | + load_or_create_model_card, |
| 48 | + populate_model_card, |
| 49 | +) |
46 | 50 | from .model_loading_utils import ( |
47 | 51 | _determine_device_map, |
48 | 52 | _load_state_dict_into_model, |
@@ -1039,3 +1043,55 @@ def recursive_find_attn_block(module) -> None: |
1039 | 1043 | del module.key |
1040 | 1044 | del module.value |
1041 | 1045 | del module.proj_attn |
| 1046 | + |
| 1047 | + |
| 1048 | +class LegacyModelMixin(ModelMixin): |
| 1049 | + r""" |
| 1050 | + A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more |
| 1051 | + pipeline-specific classes (like `DiTTransformer2DModel`). |
| 1052 | + """ |
| 1053 | + |
| 1054 | + @classmethod |
| 1055 | + @validate_hf_hub_args |
| 1056 | + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): |
| 1057 | + # To prevent depedency import problem. |
| 1058 | + from .model_loading_utils import _fetch_remapped_cls_from_config |
| 1059 | + |
| 1060 | + cache_dir = kwargs.pop("cache_dir", None) |
| 1061 | + force_download = kwargs.pop("force_download", False) |
| 1062 | + resume_download = kwargs.pop("resume_download", None) |
| 1063 | + proxies = kwargs.pop("proxies", None) |
| 1064 | + local_files_only = kwargs.pop("local_files_only", None) |
| 1065 | + token = kwargs.pop("token", None) |
| 1066 | + revision = kwargs.pop("revision", None) |
| 1067 | + subfolder = kwargs.pop("subfolder", None) |
| 1068 | + |
| 1069 | + # Load config if we don't provide a configuration |
| 1070 | + config_path = pretrained_model_name_or_path |
| 1071 | + |
| 1072 | + user_agent = { |
| 1073 | + "diffusers": __version__, |
| 1074 | + "file_type": "model", |
| 1075 | + "framework": "pytorch", |
| 1076 | + } |
| 1077 | + |
| 1078 | + # load config |
| 1079 | + config, _, _ = cls.load_config( |
| 1080 | + config_path, |
| 1081 | + cache_dir=cache_dir, |
| 1082 | + return_unused_kwargs=True, |
| 1083 | + return_commit_hash=True, |
| 1084 | + force_download=force_download, |
| 1085 | + resume_download=resume_download, |
| 1086 | + proxies=proxies, |
| 1087 | + local_files_only=local_files_only, |
| 1088 | + token=token, |
| 1089 | + revision=revision, |
| 1090 | + subfolder=subfolder, |
| 1091 | + user_agent=user_agent, |
| 1092 | + **kwargs, |
| 1093 | + ) |
| 1094 | + # resolve remapping |
| 1095 | + remapped_class = _fetch_remapped_cls_from_config(config, cls) |
| 1096 | + |
| 1097 | + return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs) |
0 commit comments