-
Notifications
You must be signed in to change notification settings - Fork 6.4k
[Feature] AutoModel can load components using model_index.json #11401
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
e506314
85024b0
6a0d0be
d86b0f2
314b6cc
528e002
6e92f40
76ea98d
0e53ad0
f697631
5614a15
f6b6b42
4e5cac1
24f16f6
684384c
0fe68cd
2950372
67e3404
694b81c
3bf51cd
13420fb
af007ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,15 +12,18 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import importlib | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Optional, Union | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from huggingface_hub.utils import validate_hf_hub_args | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..configuration_utils import ConfigMixin | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..utils import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger = logging.get_logger(__name__) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class AutoModel(ConfigMixin): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config_name = "config.json" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -153,15 +156,60 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| "token": token, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "local_files_only": local_files_only, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "revision": revision, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "subfolder": subfolder, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| orig_class_name = config["_class_name"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| library = importlib.import_module("diffusers") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| library = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| orig_class_name = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from diffusers import pipelines | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Always attempt to fetch model_index.json first | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cls.config_name = "model_index.json" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if subfolder is not None and subfolder in config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| library, orig_class_name = config[subfolder] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| load_config_kwargs.update({"subfolder": subfolder}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except EnvironmentError as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.debug(e) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Unable to load from model_index.json so fallback to loading from config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if library is None and orig_class_name is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cls.config_name = "config.json" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config = cls.load_config(pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "_class_name" in config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # If we find a class name in the config, we can try to load the model as a diffusers model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| orig_class_name = config["_class_name"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| library = "diffusers" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| load_config_kwargs.update({"subfolder": subfolder}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # If we don't find a class name in the config, we can try to load the model as a transformers model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"Doesn't look like a diffusers model. Loading {pretrained_model_or_path} as a transformer model." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "architectures" in config and len(config["architectures"]) > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if len(config["architectures"]) > 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"Found multiple architectures in {pretrained_model_or_path}. Using the first one: {config['architectures'][0]}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| orig_class_name = config["architectures"][0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| library = "transformers" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| load_config_kwargs.update({"subfolder": "" if subfolder is None else subfolder}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"Couldn't find model associated with the config file at {pretrained_model_or_path}." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | |
| # If we don't find a class name in the config, we can try to load the model as a transformers model | |
| logger.warning( | |
| f"Doesn't look like a diffusers model. Loading {pretrained_model_or_path} as a transformer model." | |
| ) | |
| if "architectures" in config and len(config["architectures"]) > 0: | |
| if len(config["architectures"]) > 1: | |
| logger.warning( | |
| f"Found multiple architectures in {pretrained_model_or_path}. Using the first one: {config['architectures'][0]}" | |
| ) | |
| orig_class_name = config["architectures"][0] | |
| library = "transformers" | |
| load_config_kwargs.update({"subfolder": "" if subfolder is None else subfolder}) | |
| else: | |
| raise ValueError( | |
| f"Couldn't find model associated with the config file at {pretrained_model_or_path}." | |
| ) | |
| elif "model_type" in config: | |
| logger.warning( | |
| f"Loading {config[model_type]} as a transformer model from {pretrained_model_or_path}." | |
| ) | |
| from transformers import AutoModel | |
| # we can use the AutoModel from tranformers here I think? | |
| .... | |
| else: | |
| raise ValueError(...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks @yiyixuxu, using AutoModel will allow us to only load from this mapping There are many other mappings in the file that we might want to load from hence I am using the architectures.
for example,
AutoModelForCausalLM - helps us import architecture like MarianForCausalLM while AutoModel doesn't
Let me know if we want to ignore all the other mappings and just use AutoModel for import ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @DN6 can you also let me know do you think here?
my opinion is that we should not implement our own logic to load transformer models (that's not part of a diffusers repo), so ok to either dispatch to their AutoModel or throw a warning for not supporting
but open to other thoughts :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I agree @yiyixuxu. For now let's keep the logic simple. We can refactor later if we need to include things like AutoModelForCausalLM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made the change !
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| import unittest | ||
| from unittest.mock import patch | ||
|
|
||
| from transformers import AlbertForMaskedLM, CLIPTextModel | ||
|
|
||
| from diffusers.models import AutoModel, UNet2DConditionModel | ||
|
|
||
|
|
||
| class TestAutoModel(unittest.TestCase): | ||
| @patch("diffusers.models.AutoModel.load_config", side_effect=[EnvironmentError("File not found"), {"_class_name": "UNet2DConditionModel"}]) | ||
| def test_load_from_config_diffusers_with_subfolder(self, mock_load_config): | ||
| model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet") | ||
| assert isinstance(model, UNet2DConditionModel) | ||
|
|
||
| @patch("diffusers.models.AutoModel.load_config", side_effect=[EnvironmentError("File not found"), {"architectures": [ "CLIPTextModel"]}]) | ||
| def test_load_from_config_transformers_with_subfolder(self, mock_load_config): | ||
| model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder") | ||
| assert isinstance(model, CLIPTextModel) | ||
|
|
||
| def test_load_from_config_without_subfolder(self): | ||
| model = AutoModel.from_pretrained("hf-internal-testing/tiny-albert") | ||
| assert isinstance(model, AlbertForMaskedLM) | ||
|
|
||
| def test_load_from_model_index(self): | ||
| model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder") | ||
| assert isinstance(model, CLIPTextModel) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would avoid using exceptions for control flow and simplify this a bit