Skip to content

Commit e54c540

Browse files
committed
first draft model loading refactor
1 parent aeac0a0 commit e54c540

16 files changed

+625
-647
lines changed

scripts/convert_sana_to_diffusers.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
SanaPipeline,
1919
SanaTransformer2DModel,
2020
)
21-
from diffusers.models.modeling_utils import load_model_dict_into_meta
21+
from diffusers.models.modeling_utils import load_state_dict_into_meta_model
2222
from diffusers.utils.import_utils import is_accelerate_available
2323

2424

@@ -189,7 +189,7 @@ def main(args):
189189
)
190190

191191
if is_accelerate_available():
192-
load_model_dict_into_meta(transformer, converted_state_dict)
192+
load_state_dict_into_meta_model(transformer, converted_state_dict)
193193
else:
194194
transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
195195

scripts/convert_sd3_to_diffusers.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77

88
from diffusers import AutoencoderKL, SD3Transformer2DModel
99
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
10-
from diffusers.models.modeling_utils import load_model_dict_into_meta
10+
from diffusers.models.modeling_utils import load_state_dict_into_meta_model
1111
from diffusers.utils.import_utils import is_accelerate_available
1212

1313

@@ -319,7 +319,7 @@ def main(args):
319319
dual_attention_layers=attn2_layers,
320320
)
321321
if is_accelerate_available():
322-
load_model_dict_into_meta(transformer, converted_transformer_state_dict)
322+
load_state_dict_into_meta_model(transformer, converted_transformer_state_dict)
323323
else:
324324
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
325325

@@ -339,7 +339,7 @@ def main(args):
339339
)
340340
converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
341341
if is_accelerate_available():
342-
load_model_dict_into_meta(vae, converted_vae_state_dict)
342+
load_state_dict_into_meta_model(vae, converted_vae_state_dict)
343343
else:
344344
vae.load_state_dict(converted_vae_state_dict, strict=True)
345345

scripts/convert_stable_audio.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
StableAudioPipeline,
1919
StableAudioProjectionModel,
2020
)
21-
from diffusers.models.modeling_utils import load_model_dict_into_meta
21+
from diffusers.models.modeling_utils import load_state_dict_into_meta_model
2222
from diffusers.utils import is_accelerate_available
2323

2424

@@ -221,7 +221,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
221221
], # assume `seconds_start` and `seconds_total` have the same min / max values.
222222
)
223223
if is_accelerate_available():
224-
load_model_dict_into_meta(projection_model, projection_model_state_dict)
224+
load_state_dict_into_meta_model(projection_model, projection_model_state_dict)
225225
else:
226226
projection_model.load_state_dict(projection_model_state_dict)
227227

@@ -242,7 +242,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
242242
cross_attention_input_dim=model_config["cond_token_dim"],
243243
)
244244
if is_accelerate_available():
245-
load_model_dict_into_meta(model, model_state_dict)
245+
load_state_dict_into_meta_model(model, model_state_dict)
246246
else:
247247
model.load_state_dict(model_state_dict)
248248

@@ -260,7 +260,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
260260
)
261261

262262
if is_accelerate_available():
263-
load_model_dict_into_meta(autoencoder, autoencoder_state_dict)
263+
load_state_dict_into_meta_model(autoencoder, autoencoder_state_dict)
264264
else:
265265
autoencoder.load_state_dict(autoencoder_state_dict)
266266

scripts/convert_stable_cascade.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
)
2121
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
2222
from diffusers.models import StableCascadeUNet
23-
from diffusers.models.modeling_utils import load_model_dict_into_meta
23+
from diffusers.models.modeling_utils import load_state_dict_into_meta_model
2424
from diffusers.pipelines.wuerstchen import PaellaVQModel
2525
from diffusers.utils import is_accelerate_available
2626

@@ -126,7 +126,7 @@
126126
switch_level=[False],
127127
)
128128
if is_accelerate_available():
129-
load_model_dict_into_meta(prior_model, prior_state_dict)
129+
load_state_dict_into_meta_model(prior_model, prior_state_dict)
130130
else:
131131
prior_model.load_state_dict(prior_state_dict)
132132

@@ -181,7 +181,7 @@
181181
)
182182

183183
if is_accelerate_available():
184-
load_model_dict_into_meta(decoder, decoder_state_dict)
184+
load_state_dict_into_meta_model(decoder, decoder_state_dict)
185185
else:
186186
decoder.load_state_dict(decoder_state_dict)
187187

scripts/convert_stable_cascade_lite.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
)
2121
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
2222
from diffusers.models import StableCascadeUNet
23-
from diffusers.models.modeling_utils import load_model_dict_into_meta
23+
from diffusers.models.modeling_utils import load_state_dict_into_meta_model
2424
from diffusers.pipelines.wuerstchen import PaellaVQModel
2525
from diffusers.utils import is_accelerate_available
2626

@@ -133,7 +133,7 @@
133133
)
134134

135135
if is_accelerate_available():
136-
load_model_dict_into_meta(prior_model, prior_state_dict)
136+
load_state_dict_into_meta_model(prior_model, prior_state_dict)
137137
else:
138138
prior_model.load_state_dict(prior_state_dict)
139139

@@ -189,7 +189,7 @@
189189
)
190190

191191
if is_accelerate_available():
192-
load_model_dict_into_meta(decoder, decoder_state_dict)
192+
load_state_dict_into_meta_model(decoder, decoder_state_dict)
193193
else:
194194
decoder.load_state_dict(decoder_state_dict)
195195

src/diffusers/loaders/single_file_model.py

Lines changed: 6 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -13,15 +13,11 @@
1313
# limitations under the License.
1414
import importlib
1515
import inspect
16-
import re
17-
from contextlib import nullcontext
1816
from typing import Optional
1917

20-
import torch
2118
from huggingface_hub.utils import validate_hf_hub_args
2219

23-
from ..quantizers import DiffusersAutoQuantizer
24-
from ..utils import deprecate, is_accelerate_available, logging
20+
from ..utils import deprecate, logging
2521
from .single_file_utils import (
2622
SingleFileComponentError,
2723
convert_animatediff_checkpoint_to_diffusers,
@@ -49,12 +45,6 @@
4945
logger = logging.get_logger(__name__)
5046

5147

52-
if is_accelerate_available():
53-
from accelerate import init_empty_weights
54-
55-
from ..models.modeling_utils import load_model_dict_into_meta
56-
57-
5848
SINGLE_FILE_LOADABLE_CLASSES = {
5949
"StableCascadeUNet": {
6050
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
@@ -234,9 +224,6 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
234224
subfolder = kwargs.pop("subfolder", None)
235225
revision = kwargs.pop("revision", None)
236226
config_revision = kwargs.pop("config_revision", None)
237-
torch_dtype = kwargs.pop("torch_dtype", None)
238-
quantization_config = kwargs.pop("quantization_config", None)
239-
device = kwargs.pop("device", None)
240227
disable_mmap = kwargs.pop("disable_mmap", False)
241228

242229
if isinstance(pretrained_model_link_or_path_or_dict, dict):
@@ -252,12 +239,6 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
252239
revision=revision,
253240
disable_mmap=disable_mmap,
254241
)
255-
if quantization_config is not None:
256-
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
257-
hf_quantizer.validate_environment()
258-
259-
else:
260-
hf_quantizer = None
261242

262243
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
263244

@@ -336,62 +317,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
336317
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
337318
)
338319

339-
ctx = init_empty_weights if is_accelerate_available() else nullcontext
340-
with ctx():
341-
model = cls.from_config(diffusers_model_config)
342-
343-
# Check if `_keep_in_fp32_modules` is not None
344-
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
345-
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
320+
return cls.from_pretrained(
321+
pretrained_model_name_or_path=None,
322+
state_dict=diffusers_format_checkpoint,
323+
config=diffusers_model_config,
324+
**kwargs,
346325
)
347-
if use_keep_in_fp32_modules:
348-
keep_in_fp32_modules = cls._keep_in_fp32_modules
349-
if not isinstance(keep_in_fp32_modules, list):
350-
keep_in_fp32_modules = [keep_in_fp32_modules]
351-
352-
else:
353-
keep_in_fp32_modules = []
354-
355-
if hf_quantizer is not None:
356-
hf_quantizer.preprocess_model(
357-
model=model,
358-
device_map=None,
359-
state_dict=diffusers_format_checkpoint,
360-
keep_in_fp32_modules=keep_in_fp32_modules,
361-
)
362-
363-
if is_accelerate_available():
364-
param_device = torch.device(device) if device else torch.device("cpu")
365-
named_buffers = model.named_buffers()
366-
unexpected_keys = load_model_dict_into_meta(
367-
model,
368-
diffusers_format_checkpoint,
369-
dtype=torch_dtype,
370-
device=param_device,
371-
hf_quantizer=hf_quantizer,
372-
keep_in_fp32_modules=keep_in_fp32_modules,
373-
named_buffers=named_buffers,
374-
)
375-
376-
else:
377-
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
378-
379-
if model._keys_to_ignore_on_load_unexpected is not None:
380-
for pat in model._keys_to_ignore_on_load_unexpected:
381-
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
382-
383-
if len(unexpected_keys) > 0:
384-
logger.warning(
385-
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
386-
)
387-
388-
if hf_quantizer is not None:
389-
hf_quantizer.postprocess_model(model)
390-
model.hf_quantizer = hf_quantizer
391-
392-
if torch_dtype is not None and hf_quantizer is None:
393-
model.to(torch_dtype)
394-
395-
model.eval()
396-
397-
return model

src/diffusers/loaders/single_file_utils.py

Lines changed: 4 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@
5353
if is_accelerate_available():
5454
from accelerate import init_empty_weights
5555

56-
from ..models.modeling_utils import load_model_dict_into_meta
56+
from ..models.modeling_utils import load_state_dict_into_meta_model
5757

5858
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
5959

@@ -1588,18 +1588,9 @@ def create_diffusers_clip_model_from_ldm(
15881588
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
15891589

15901590
if is_accelerate_available():
1591-
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
1591+
load_state_dict_into_meta_model(model, diffusers_format_checkpoint, dtype=torch_dtype)
15921592
else:
1593-
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
1594-
1595-
if model._keys_to_ignore_on_load_unexpected is not None:
1596-
for pat in model._keys_to_ignore_on_load_unexpected:
1597-
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1598-
1599-
if len(unexpected_keys) > 0:
1600-
logger.warning(
1601-
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1602-
)
1593+
model.load_state_dict(diffusers_format_checkpoint, strict=False)
16031594

16041595
if torch_dtype is not None:
16051596
model.to(torch_dtype)
@@ -2056,16 +2047,7 @@ def create_diffusers_t5_model_from_checkpoint(
20562047
diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
20572048

20582049
if is_accelerate_available():
2059-
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
2060-
if model._keys_to_ignore_on_load_unexpected is not None:
2061-
for pat in model._keys_to_ignore_on_load_unexpected:
2062-
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
2063-
2064-
if len(unexpected_keys) > 0:
2065-
logger.warning(
2066-
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
2067-
)
2068-
2050+
load_state_dict_into_meta_model(model, diffusers_format_checkpoint, dtype=torch_dtype)
20692051
else:
20702052
model.load_state_dict(diffusers_format_checkpoint)
20712053

src/diffusers/loaders/transformer_flux.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
ImageProjection,
1818
MultiIPAdapterImageProjection,
1919
)
20-
from ..models.modeling_utils import load_model_dict_into_meta
20+
from ..models.modeling_utils import load_state_dict_into_meta_model
2121
from ..utils import (
2222
is_accelerate_available,
2323
is_torch_version,
@@ -82,7 +82,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
8282
if not low_cpu_mem_usage:
8383
image_projection.load_state_dict(updated_state_dict, strict=True)
8484
else:
85-
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
85+
load_state_dict_into_meta_model(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
8686

8787
return image_projection
8888

@@ -153,7 +153,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
153153
else:
154154
device = self.device
155155
dtype = self.dtype
156-
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
156+
load_state_dict_into_meta_model(attn_procs[name], value_dict, device=device, dtype=dtype)
157157

158158
key_id += 1
159159

src/diffusers/loaders/transformer_sd3.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515

1616
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
1717
from ..models.embeddings import IPAdapterTimeImageProjection
18-
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
18+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict_into_meta_model
1919

2020

2121
class SD3Transformer2DLoadersMixin:
@@ -59,7 +59,7 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _
5959
if not low_cpu_mem_usage:
6060
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
6161
else:
62-
load_model_dict_into_meta(
62+
load_state_dict_into_meta_model(
6363
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
6464
)
6565

@@ -86,4 +86,6 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _
8686
if not low_cpu_mem_usage:
8787
self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
8888
else:
89-
load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
89+
load_state_dict_into_meta_model(
90+
self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype
91+
)

src/diffusers/loaders/unet.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@
3030
IPAdapterPlusImageProjection,
3131
MultiIPAdapterImageProjection,
3232
)
33-
from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
33+
from ..models.modeling_utils import load_state_dict, load_state_dict_into_meta_model
3434
from ..utils import (
3535
USE_PEFT_BACKEND,
3636
_get_model_file,
@@ -753,7 +753,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
753753
if not low_cpu_mem_usage:
754754
image_projection.load_state_dict(updated_state_dict, strict=True)
755755
else:
756-
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
756+
load_state_dict_into_meta_model(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
757757

758758
return image_projection
759759

@@ -846,7 +846,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
846846
else:
847847
device = next(iter(value_dict.values())).device
848848
dtype = next(iter(value_dict.values())).dtype
849-
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
849+
load_state_dict_into_meta_model(attn_procs[name], value_dict, device=device, dtype=dtype)
850850

851851
key_id += 2
852852

0 commit comments

Comments
 (0)