-
Notifications
You must be signed in to change notification settings - Fork 6.4k
[Feature] AutoModel can load components using model_index.json #11401
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
e506314
85024b0
6a0d0be
d86b0f2
314b6cc
528e002
6e92f40
76ea98d
0e53ad0
f697631
5614a15
f6b6b42
4e5cac1
24f16f6
684384c
0fe68cd
2950372
67e3404
694b81c
3bf51cd
13420fb
af007ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,7 @@ | |
|
|
||
|
|
||
| class AutoModel(ConfigMixin): | ||
| config_name = "config.json" | ||
| config_name = "model_index.json" | ||
|
|
||
| def __init__(self, *args, **kwargs): | ||
| raise EnvironmentError( | ||
|
|
@@ -156,10 +156,17 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi | |
| "subfolder": subfolder, | ||
| } | ||
|
|
||
| config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would avoid using exceptions for control flow and simplify this a bit load_config_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"token": token,
"local_files_only": local_files_only,
"revision": revision,
}
library = None
orig_class_name = None
from diffusers import pipelines
# Always attempt to fetch model_index.json first
try:
cls.config_name = "model_index.json"
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
if subfolder is not None and subfolder in config:
library, orig_class_name = config[subfolder]
except EntryNotFoundError as e:
logger.debug(e)
# Unable to load from model_index.json so fallback to loading from config
if library is None and orig_class_name is None:
cls.config_name = "config.json"
load_config_kwargs.update({"subfolder": subfolder})
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
library = "diffusers"
model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=orig_class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=pipelines,
is_pipeline_module=hasattr(pipelines, library),
) |
||
| orig_class_name = config["_class_name"] | ||
|
|
||
| library = importlib.import_module("diffusers") | ||
| try: | ||
| mindex_kwargs = {k: v for k, v in load_config_kwargs.items() if k != "subfolder"} | ||
|
||
| config = cls.load_config(os.path.join(pretrained_model_or_path), **mindex_kwargs) | ||
| library, orig_class_name = config[subfolder] | ||
| library = importlib.import_module(library) | ||
sayakpaul marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| except Exception: | ||
|
||
| # Fallback to loading the config from the config.json file | ||
| cls.config_name = "config.json" | ||
| config = cls.load_config(os.path.join(pretrained_model_or_path), **load_config_kwargs) | ||
| library = importlib.import_module("diffusers") | ||
| orig_class_name = config["_class_name"] | ||
|
|
||
| model_cls = getattr(library, orig_class_name, None) | ||
| if model_cls is None: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.