Skip to content

Commit 80de641

Browse files
authored
Allow Automodel to support custom model code (huggingface#12353)
* update * update
1 parent 76810ec commit 80de641

File tree

2 files changed

+52
-24
lines changed

2 files changed

+52
-24
lines changed

src/diffusers/models/auto_model.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from ..configuration_utils import ConfigMixin
2121
from ..utils import logging
22+
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
2223

2324

2425
logger = logging.get_logger(__name__)
@@ -114,6 +115,8 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
114115
disable_mmap ('bool', *optional*, defaults to 'False'):
115116
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
116117
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
118+
trust_remote_cocde (`bool`, *optional*, defaults to `False`):
119+
Whether to trust remote code
117120
118121
<Tip>
119122
@@ -140,22 +143,22 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
140143
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
141144
```
142145
"""
143-
cache_dir = kwargs.pop("cache_dir", None)
144-
force_download = kwargs.pop("force_download", False)
145-
proxies = kwargs.pop("proxies", None)
146-
token = kwargs.pop("token", None)
147-
local_files_only = kwargs.pop("local_files_only", False)
148-
revision = kwargs.pop("revision", None)
149146
subfolder = kwargs.pop("subfolder", None)
150-
151-
load_config_kwargs = {
152-
"cache_dir": cache_dir,
153-
"force_download": force_download,
154-
"proxies": proxies,
155-
"token": token,
156-
"local_files_only": local_files_only,
157-
"revision": revision,
158-
}
147+
trust_remote_code = kwargs.pop("trust_remote_code", False)
148+
149+
hub_kwargs_names = [
150+
"cache_dir",
151+
"force_download",
152+
"local_files_only",
153+
"proxies",
154+
"resume_download",
155+
"revision",
156+
"token",
157+
]
158+
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
159+
160+
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
161+
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
159162

160163
library = None
161164
orig_class_name = None
@@ -189,15 +192,35 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
189192
else:
190193
raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")
191194

192-
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
193-
194-
model_cls, _ = get_class_obj_and_candidates(
195-
library_name=library,
196-
class_name=orig_class_name,
197-
importable_classes=ALL_IMPORTABLE_CLASSES,
198-
pipelines=None,
199-
is_pipeline_module=False,
200-
)
195+
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
196+
trust_remote_code = resolve_trust_remote_code(trust_remote_code, pretrained_model_or_path, has_remote_code)
197+
if not (has_remote_code and trust_remote_code):
198+
raise ValueError(
199+
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
200+
)
201+
202+
if has_remote_code and trust_remote_code:
203+
class_ref = config["auto_map"][cls.__name__]
204+
module_file, class_name = class_ref.split(".")
205+
module_file = module_file + ".py"
206+
model_cls = get_class_from_dynamic_module(
207+
pretrained_model_or_path,
208+
subfolder=subfolder,
209+
module_file=module_file,
210+
class_name=class_name,
211+
**hub_kwargs,
212+
**kwargs,
213+
)
214+
else:
215+
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
216+
217+
model_cls, _ = get_class_obj_and_candidates(
218+
library_name=library,
219+
class_name=orig_class_name,
220+
importable_classes=ALL_IMPORTABLE_CLASSES,
221+
pipelines=None,
222+
is_pipeline_module=False,
223+
)
201224

202225
if model_cls is None:
203226
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")

src/diffusers/utils/dynamic_modules_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def find_pipeline_class(loaded_module):
247247
def get_cached_module_file(
248248
pretrained_model_name_or_path: Union[str, os.PathLike],
249249
module_file: str,
250+
subfolder: Optional[str] = None,
250251
cache_dir: Optional[Union[str, os.PathLike]] = None,
251252
force_download: bool = False,
252253
proxies: Optional[Dict[str, str]] = None,
@@ -353,6 +354,7 @@ def get_cached_module_file(
353354
resolved_module_file = hf_hub_download(
354355
pretrained_model_name_or_path,
355356
module_file,
357+
subfolder=subfolder,
356358
cache_dir=cache_dir,
357359
force_download=force_download,
358360
proxies=proxies,
@@ -410,6 +412,7 @@ def get_cached_module_file(
410412
get_cached_module_file(
411413
pretrained_model_name_or_path,
412414
f"{module_needed}.py",
415+
subfolder=subfolder,
413416
cache_dir=cache_dir,
414417
force_download=force_download,
415418
proxies=proxies,
@@ -424,6 +427,7 @@ def get_cached_module_file(
424427
def get_class_from_dynamic_module(
425428
pretrained_model_name_or_path: Union[str, os.PathLike],
426429
module_file: str,
430+
subfolder: Optional[str] = None,
427431
class_name: Optional[str] = None,
428432
cache_dir: Optional[Union[str, os.PathLike]] = None,
429433
force_download: bool = False,
@@ -497,6 +501,7 @@ def get_class_from_dynamic_module(
497501
final_module = get_cached_module_file(
498502
pretrained_model_name_or_path,
499503
module_file,
504+
subfolder=subfolder,
500505
cache_dir=cache_dir,
501506
force_download=force_download,
502507
proxies=proxies,

0 commit comments

Comments
 (0)