Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
"T2IAdapter",
"T5FilmDecoder",
"Transformer2DModel",
"TransformerTemporalModel",
"UNet1DModel",
"UNet2DConditionModel",
"UNet2DModel",
Expand Down Expand Up @@ -649,6 +650,7 @@
T2IAdapter,
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,
Expand Down
55 changes: 55 additions & 0 deletions src/diffusers/models/auto_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import importlib
import os
from typing import Optional, Union

from huggingface_hub.utils import validate_hf_hub_args

from ..configuration_utils import ConfigMixin
from ..utils import CONFIG_NAME
from .modeling_utils import ModelMixin


class AutoModel(ConfigMixin):
config_name = CONFIG_NAME

"""TODO"""

def __init__(self, *args, **kwargs):
raise EnvironmentError(
f"{self.__class__.__name__} is designed to be instantiated "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
"from_config(config) methods."
)

@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
"""TODO"""
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)

load_config_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"token": token,
"local_files_only": local_files_only,
"revision": revision,
"subfolder": subfolder,
}
config = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
class_name = config["_class_name"]
diffusers_module = importlib.import_module(__name__.split(".")[0])

try:
model_cls: ModelMixin = getattr(diffusers_module, class_name)
except Exception:
raise ValueError(f"Could not import the `{class_name}` class from diffusers.")

kwargs = {**load_config_kwargs, **kwargs}
return model_cls.from_pretrained(pretrained_model_name_or_path, **kwargs)
15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class TransformerTemporalModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class UNet1DModel(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
41 changes: 41 additions & 0 deletions tests/models/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
AttnProcessorNPU,
XFormersAttnProcessor,
)
from diffusers.models.auto_model import AutoModel
from diffusers.training_utils import EMAModel
from diffusers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
Expand Down Expand Up @@ -1424,6 +1425,46 @@ def get_memory_usage(storage_dtype, compute_dtype):
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
)

@parameterized.expand([None, "foo"])
def test_works_with_automodel(self, subfolder):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
has_generator_in_inputs = False
if "generator" in inputs_dict:
has_generator_in_inputs = True
inputs_dict["generator"] = torch.manual_seed(0)

model = self.model_class(**config).eval()
model_cls_name = model.__class__.__name__
model.to(torch_device)

torch.manual_seed(0)
output = model(**inputs_dict, return_dict=False)[0]

with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, subfolder) if subfolder else tmpdir
model.save_pretrained(path)
automodel = AutoModel.from_pretrained(tmpdir, subfolder=subfolder).eval()
automodel.to(torch_device)

automodel_cls_name = automodel.__class__.__name__
self.assertTrue(model_cls_name == automodel_cls_name)
for p1, p2 in zip(model.parameters(), automodel.parameters()):
if not (torch.isnan(p1).any() and torch.isnan(p2).any()):
self.assertTrue(torch.equal(p1, p2))

torch.manual_seed(0)
if has_generator_in_inputs:
inputs_dict["generator"] = torch.manual_seed(0)
output_automodel = model(**inputs_dict, return_dict=False)[0]

self.assertTrue(torch.allclose(output[0], output_automodel[0], atol=1e-5))

def test_automodel_raises_error_with_direct_init(self):
config, _ = self.prepare_init_args_and_inputs_for_common()
with self.assertRaises(EnvironmentError) as err_context:
_ = AutoModel(**config)
self.assertTrue("is designed to be instantiated" in str(err_context.exception))


@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
Expand Down