
Commit c49a855

Merge branch 'main' into test-better-torch-compile
2 parents: d669340 + 437cb36

7 files changed: 271 additions, 7 deletions

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -155,6 +155,7 @@
         "AutoencoderKLWan",
         "AutoencoderOobleck",
         "AutoencoderTiny",
+        "AutoModel",
         "CacheMixin",
         "CogVideoXTransformer3DModel",
         "CogView3PlusTransformer2DModel",
@@ -197,6 +198,7 @@
         "T2IAdapter",
         "T5FilmDecoder",
         "Transformer2DModel",
+        "TransformerTemporalModel",
         "UNet1DModel",
         "UNet2DConditionModel",
         "UNet2DModel",
@@ -731,6 +733,7 @@
         AutoencoderKLWan,
         AutoencoderOobleck,
         AutoencoderTiny,
+        AutoModel,
         CacheMixin,
         CogVideoXTransformer3DModel,
         CogView3PlusTransformer2DModel,
@@ -772,6 +775,7 @@
         T2IAdapter,
         T5FilmDecoder,
         Transformer2DModel,
+        TransformerTemporalModel,
         UNet1DModel,
         UNet2DConditionModel,
         UNet2DModel,
```
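Both new names are registered twice, once in the lazy `_import_structure` map and once in the eager `TYPE_CHECKING`/slow-import branch. A quick sanity check (a sketch, not part of the diff) that they resolve from the package root after this change:

```python
import diffusers

# Both names should now be reachable as top-level attributes.
assert hasattr(diffusers, "AutoModel")
assert hasattr(diffusers, "TransformerTemporalModel")
```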

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -41,6 +41,7 @@
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["autoencoders.vq_model"] = ["VQModel"]
+    _import_structure["auto_model"] = ["AutoModel"]
     _import_structure["cache_utils"] = ["CacheMixin"]
     _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
     _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
@@ -103,6 +104,7 @@
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .adapter import MultiAdapter, T2IAdapter
+        from .auto_model import AutoModel
         from .autoencoders import (
             AsymmetricAutoencoderKL,
             AutoencoderDC,
```
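For context, `_import_structure` is the package's lazy-import map: a name is only imported from its submodule on first attribute access. A minimal sketch of the idea (simplified; diffusers' actual machinery is a `_LazyModule` class, not a module-level hook):

```python
# Simplified lazy import via a PEP 562 module-level __getattr__ (sketch only).
import importlib

_import_structure = {"auto_model": ["AutoModel"]}

def __getattr__(name):
    # Import the owning submodule only when one of its names is first requested.
    for submodule, exported_names in _import_structure.items():
        if name in exported_names:
            module = importlib.import_module(f".{submodule}", __name__)
            return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```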

src/diffusers/models/auto_model.py

Lines changed: 169 additions & 0 deletions
New file:

````python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import os
from typing import Optional, Union

from huggingface_hub.utils import validate_hf_hub_args

from ..configuration_utils import ConfigMixin


class AutoModel(ConfigMixin):
    config_name = "config.json"

    def __init__(self, *args, **kwargs):
        raise EnvironmentError(
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
            f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
        )

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs):
        r"""
        Instantiate a pretrained PyTorch model from a pretrained model configuration.

        The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
        train the model, set it back in training mode with `model.train()`.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`~ModelMixin.save_pretrained`].

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
                dtype is automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            from_flax (`bool`, *optional*, defaults to `False`):
                Load the model weights from a Flax checkpoint save file.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be defined for each
                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                same device. Defaults to `None`, meaning that the model will be loaded on CPU.

                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary mapping device identifiers to their maximum memory. Defaults to the maximum memory
                available for each GPU and the available CPU RAM if unset.
            offload_folder (`str` or `os.PathLike`, *optional*):
                The path to offload weights if `device_map` contains the value `"disk"`.
            offload_state_dict (`bool`, *optional*):
                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
                the weight of the CPU state dict plus the biggest shard of the checkpoint does not fit. Defaults to
                `True` when there is some disk offload.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
                also tries to not use more than 1x the model size in CPU memory (including peak memory) while loading
                the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting
                this argument to `True` will raise an error.
            variant (`str`, *optional*):
                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                weights. If set to `False`, `safetensors` weights are not loaded.
            disable_mmap (`bool`, *optional*, defaults to `False`):
                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                is on a network mount or hard drive, which may not handle the seek-heavy access pattern of mmap very
                well.

        <Tip>

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log in with
        `huggingface-cli login`. You can also activate the special
        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
        firewalled environment.

        </Tip>

        Example:

        ```py
        from diffusers import AutoModel

        unet = AutoModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
        ```

        If you get the error message below, you need to finetune the weights for your downstream task:

        ```bash
        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
        ```
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)

        load_config_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "token": token,
            "local_files_only": local_files_only,
            "revision": revision,
            "subfolder": subfolder,
        }

        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
        orig_class_name = config["_class_name"]

        library = importlib.import_module("diffusers")

        model_cls = getattr(library, orig_class_name, None)
        if model_cls is None:
            raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")

        kwargs = {**load_config_kwargs, **kwargs}
        return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
````
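The core of `AutoModel.from_pretrained` is the class-resolution step: read `_class_name` from the checkpoint's `config.json` and look that name up on the top-level `diffusers` module. A standalone sketch of just that step (the `resolve_model_class` helper is hypothetical, for illustration only):

```python
import importlib
import json

def resolve_model_class(config_path: str):
    # config.json for diffusers models records the concrete class under
    # "_class_name", e.g. "UNet2DConditionModel".
    with open(config_path) as f:
        class_name = json.load(f)["_class_name"]
    diffusers = importlib.import_module("diffusers")
    model_cls = getattr(diffusers, class_name, None)
    if model_cls is None:
        raise ValueError(f"AutoModel can't find a model linked to {class_name}.")
    return model_cls
```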

src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py

Lines changed: 8 additions & 5 deletions
```diff
@@ -20,7 +20,7 @@
 from transformers import (
     ClapFeatureExtractor,
     ClapModel,
-    GPT2Model,
+    GPT2LMHeadModel,
     RobertaTokenizer,
     RobertaTokenizerFast,
     SpeechT5HifiGan,
@@ -196,7 +196,7 @@ def __init__(
         text_encoder: ClapModel,
         text_encoder_2: Union[T5EncoderModel, VitsModel],
         projection_model: AudioLDM2ProjectionModel,
-        language_model: GPT2Model,
+        language_model: GPT2LMHeadModel,
         tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
         tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
         feature_extractor: ClapFeatureExtractor,
@@ -259,7 +259,10 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
         )

         device_type = torch_device.type
-        device = torch.device(f"{device_type}:{gpu_id or torch_device.index}")
+        device_str = device_type
+        if gpu_id or torch_device.index:
+            device_str = f"{device_str}:{gpu_id or torch_device.index}"
+        device = torch.device(device_str)

         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
```
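The `enable_model_cpu_offload` change fixes a crash when neither `gpu_id` nor a device index is given: the old f-string interpolated `None`, producing strings like `"cuda:None"` that `torch.device` rejects. A small sketch of the two behaviours (assumes only that `torch` is installed):

```python
import torch

device_type = "cuda"
index = None  # torch.device("cuda").index is None when no index is given

# Old behaviour: interpolating None produced an invalid device string.
# torch.device(f"{device_type}:{index}")  # RuntimeError on "cuda:None"

# New behaviour: append the index only when one is known. A bare "cuda"
# resolves to the current device. (Index 0 is falsy, so it also falls back
# to the bare type, which normally points at device 0 anyway.)
device_str = device_type
if index:
    device_str = f"{device_str}:{index}"
print(torch.device(device_str))  # device(type='cuda')
```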
```diff
@@ -316,9 +319,9 @@ def generate_language_model(
             model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)

             # forward pass to get next hidden states
-            output = self.language_model(**model_inputs, return_dict=True)
+            output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True)

-            next_hidden_states = output.last_hidden_state
+            next_hidden_states = output.hidden_states[-1]

             # Update the model input
             inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
```
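The switch from `GPT2Model` to `GPT2LMHeadModel` changes the output type: the LM-head model returns logits and only exposes the transformer's hidden states when `output_hidden_states=True` is passed, with the final layer at `hidden_states[-1]` (the same tensor `GPT2Model` returned as `last_hidden_state`). A minimal sketch (assumes `transformers` is installed; the tiny config is hypothetical, chosen for speed):

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(n_embd=8, n_layer=2, n_head=2, vocab_size=16)
model = GPT2LMHeadModel(config).eval()

inputs_embeds = torch.randn(1, 3, config.n_embd)
with torch.no_grad():
    output = model(inputs_embeds=inputs_embeds, output_hidden_states=True, return_dict=True)

# hidden_states holds n_layer + 1 tensors (embeddings plus each block's output);
# the last entry is the final transformer state the pipeline feeds back in.
print(len(output.hidden_states))       # 3
print(output.hidden_states[-1].shape)  # torch.Size([1, 3, 8])
```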

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -280,6 +280,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+class AutoModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class CacheMixin(metaclass=DummyObject):
     _backends = ["torch"]

@@ -895,6 +910,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+class TransformerTemporalModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class UNet1DModel(metaclass=DummyObject):
     _backends = ["torch"]
```
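These placeholders follow the repository's dummy-object pattern: when torch is not installed, the real classes are replaced by stand-ins that raise a clear error on instantiation or on `from_config`/`from_pretrained`. A simplified sketch of the mechanism (the real `DummyObject` and `requires_backends` live in `diffusers.utils` and also check backend availability before raising):

```python
# Simplified sketch of the dummy-object pattern, not the actual diffusers code.
def requires_backends(obj, backends):
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")

class DummyObject(type):
    # Metaclass hook: any class-level attribute access on the dummy also raises.
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)

class AutoModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

# AutoModel()          -> ImportError: AutoModel requires the following backends: torch
# AutoModel.some_attr  -> ImportError via the metaclass __getattr__
```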

tests/models/test_modeling_common.py

Lines changed: 44 additions & 0 deletions
```diff
@@ -45,6 +45,7 @@
     AttnProcessorNPU,
     XFormersAttnProcessor,
 )
+from diffusers.models.auto_model import AutoModel
 from diffusers.training_utils import EMAModel
 from diffusers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
@@ -1577,6 +1578,49 @@ def run_forward(model):
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))

+    def test_auto_model(self, expected_max_diff=5e-5):
+        if self.forward_requires_fresh_args:
+            model = self.model_class(**self.init_dict)
+        else:
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict)
+
+        model = model.eval()
+        model = model.to(torch_device)
+
+        if hasattr(model, "set_default_attn_processor"):
+            model.set_default_attn_processor()
+
+        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
+            model.save_pretrained(tmpdirname, safe_serialization=False)
+
+            auto_model = AutoModel.from_pretrained(tmpdirname)
+            if hasattr(auto_model, "set_default_attn_processor"):
+                auto_model.set_default_attn_processor()
+
+        auto_model = auto_model.eval()
+        auto_model = auto_model.to(torch_device)
+
+        with torch.no_grad():
+            if self.forward_requires_fresh_args:
+                output_original = model(**self.inputs_dict(0))
+                output_auto = auto_model(**self.inputs_dict(0))
+            else:
+                output_original = model(**inputs_dict)
+                output_auto = auto_model(**inputs_dict)
+
+        if isinstance(output_original, dict):
+            output_original = output_original.to_tuple()[0]
+        if isinstance(output_auto, dict):
+            output_auto = output_auto.to_tuple()[0]
+
+        max_diff = (output_original - output_auto).abs().max().item()
+        self.assertLessEqual(
+            max_diff,
+            expected_max_diff,
+            f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}",
+        )
+

 @is_staging_test
 class ModelPushToHubTester(unittest.TestCase):
```
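The new test saves a model and reloads it through `AutoModel`, asserting that the two forward passes match. The same round trip outside the test harness might look like this (a sketch; the tiny `UNet2DModel` config below is hypothetical, chosen for speed):

```python
import tempfile

import torch
from diffusers import AutoModel, UNet2DModel

model = UNet2DModel(
    sample_size=8,
    in_channels=3,
    out_channels=3,
    block_out_channels=(8, 16),
    norm_num_groups=4,
    down_block_types=("DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D"),
).eval()

with tempfile.TemporaryDirectory() as tmpdir:
    model.save_pretrained(tmpdir)
    # AutoModel reads _class_name from config.json and returns a UNet2DModel.
    reloaded = AutoModel.from_pretrained(tmpdir).eval()

assert type(reloaded) is UNet2DModel

sample, timestep = torch.randn(1, 3, 8, 8), torch.tensor([1])
with torch.no_grad():
    diff = (model(sample, timestep).sample - reloaded(sample, timestep).sample).abs().max()
print(diff.item())  # expected ~0
```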

tests/pipelines/audioldm2/test_audioldm2.py

Lines changed: 14 additions & 2 deletions
```diff
@@ -26,7 +26,7 @@
     ClapModel,
     ClapTextConfig,
     GPT2Config,
-    GPT2Model,
+    GPT2LMHeadModel,
     RobertaTokenizer,
     SpeechT5HifiGan,
     SpeechT5HifiGanConfig,
@@ -162,7 +162,7 @@ def get_dummy_components(self):
             n_ctx=99,
             n_positions=99,
         )
-        language_model = GPT2Model(language_model_config)
+        language_model = GPT2LMHeadModel(language_model_config)
         language_model.config.max_new_tokens = 8

         torch.manual_seed(0)
@@ -516,6 +516,18 @@ def test_sequential_cpu_offload_forward_pass(self):
     def test_encode_prompt_works_in_isolation(self):
         pass

+    @unittest.skip("Not supported yet due to CLAPModel.")
+    def test_sequential_offload_forward_pass_twice(self):
+        pass
+
+    @unittest.skip("Not supported yet, the second forward has mixed devices and `vocoder` is not offloaded.")
+    def test_cpu_offload_forward_pass_twice(self):
+        pass
+
+    @unittest.skip("Not supported yet. `vocoder` is not offloaded.")
+    def test_model_cpu_offload_forward_pass(self):
+        pass
+

 @nightly
 class AudioLDM2PipelineSlowTests(unittest.TestCase):
```
