Skip to content

Commit b982eca

Browse files
committed
feat: implement automodel for diffusers.
1 parent c4d4ac2 commit b982eca

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

src/diffusers/models/auto_model.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import importlib
2+
import os
3+
from typing import Optional, Union
4+
5+
from huggingface_hub.utils import validate_hf_hub_args
6+
7+
from ..configuration_utils import ConfigMixin
8+
from ..utils import CONFIG_NAME
9+
from .modeling_utils import ModelMixin
10+
11+
12+
class AutoModel(ConfigMixin):
13+
config_name = CONFIG_NAME
14+
15+
"""TODO"""
16+
17+
def __init__(self, *args, **kwargs):
18+
raise EnvironmentError(
19+
f"{self.__class__.__name__} is designed to be instantiated "
20+
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
21+
"from_config(config) methods."
22+
)
23+
24+
@classmethod
25+
@validate_hf_hub_args
26+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
27+
"""TODO"""
28+
cache_dir = kwargs.pop("cache_dir", None)
29+
force_download = kwargs.pop("force_download", False)
30+
proxies = kwargs.pop("proxies", None)
31+
token = kwargs.pop("token", None)
32+
local_files_only = kwargs.pop("local_files_only", False)
33+
revision = kwargs.pop("revision", None)
34+
subfolder = kwargs.pop("subfolder", None)
35+
36+
load_config_kwargs = {
37+
"cache_dir": cache_dir,
38+
"force_download": force_download,
39+
"proxies": proxies,
40+
"token": token,
41+
"local_files_only": local_files_only,
42+
"revision": revision,
43+
"subfolder": subfolder,
44+
}
45+
config = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
46+
class_name = config["_class_name"]
47+
diffusers_module = importlib.import_module(__name__.split(".")[0])
48+
49+
try:
50+
model_cls: ModelMixin = getattr(diffusers_module, class_name)
51+
except Exception:
52+
raise ValueError(f"Could not import the `{class_name}` class from diffusers.")
53+
54+
kwargs = {**load_config_kwargs, **kwargs}
55+
return model_cls.from_pretrained(pretrained_model_name_or_path, **kwargs)

tests/models/test_modeling_common.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
AttnProcessorNPU,
4545
XFormersAttnProcessor,
4646
)
47+
from diffusers.models.auto_model import AutoModel
4748
from diffusers.training_utils import EMAModel
4849
from diffusers.utils import (
4950
SAFE_WEIGHTS_INDEX_NAME,
@@ -1424,6 +1425,37 @@ def get_memory_usage(storage_dtype, compute_dtype):
14241425
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
14251426
)
14261427

1428+
@parameterized.expand([None, "foo"])
1429+
def test_works_with_automodel(self, subfolder):
1430+
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1431+
model = self.model_class(**config).eval()
1432+
model_cls_name = model.__class__.__name__
1433+
model.to(torch_device)
1434+
1435+
torch.manual_seed(0)
1436+
output = model(**inputs_dict, return_dict=False)[0]
1437+
1438+
with tempfile.TemporaryDirectory() as tmpdir:
1439+
path = os.path.join(tmpdir, subfolder) if subfolder else tmpdir
1440+
model.save_pretrained(path)
1441+
automodel = AutoModel.from_pretrained(tmpdir, subfolder=subfolder).to(torch_device)
1442+
1443+
automodel_cls_name = automodel.__class__.__name__
1444+
self.assertTrue(model_cls_name == automodel_cls_name)
1445+
for p1, p2 in zip(model.parameters(), automodel.parameters()):
1446+
self.assertTrue(torch.equal(p1, p2))
1447+
1448+
torch.manual_seed(0)
1449+
output_automodel = model(**inputs_dict, return_dict=False)[0]
1450+
1451+
self.assertTrue(torch.allclose(output[0], output_automodel[0], atol=1e-5))
1452+
1453+
def test_automodel_raises_error_with_direct_init(self):
1454+
config, _ = self.prepare_init_args_and_inputs_for_common()
1455+
with self.assertRaises(EnvironmentError) as err_context:
1456+
_ = AutoModel(**config)
1457+
self.assertTrue("is designed to be instantiated" in str(err_context.exception))
1458+
14271459

14281460
@is_staging_test
14291461
class ModelPushToHubTester(unittest.TestCase):

0 commit comments

Comments
 (0)