Skip to content

Commit 81900b2

Browse files
Commit 81900b2 — "Merge branch 'main' into lumina2-lora" (merge commit with 2 parents: a8fb6ac and f5929e0).

File tree: 12 files changed (+844 additions, −515 deletions).

src/diffusers/loaders/single_file_model.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252

5353

5454
if is_accelerate_available():
55-
from accelerate import init_empty_weights
55+
from accelerate import dispatch_model, init_empty_weights
5656

5757
from ..models.modeling_utils import load_model_dict_into_meta
5858

@@ -366,19 +366,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
366366
keep_in_fp32_modules=keep_in_fp32_modules,
367367
)
368368

369+
device_map = None
369370
if is_accelerate_available():
370371
param_device = torch.device(device) if device else torch.device("cpu")
371-
named_buffers = model.named_buffers()
372-
unexpected_keys = load_model_dict_into_meta(
372+
empty_state_dict = model.state_dict()
373+
unexpected_keys = [
374+
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
375+
]
376+
device_map = {"": param_device}
377+
load_model_dict_into_meta(
373378
model,
374379
diffusers_format_checkpoint,
375380
dtype=torch_dtype,
376-
device=param_device,
381+
device_map=device_map,
377382
hf_quantizer=hf_quantizer,
378383
keep_in_fp32_modules=keep_in_fp32_modules,
379-
named_buffers=named_buffers,
384+
unexpected_keys=unexpected_keys,
380385
)
381-
382386
else:
383387
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
384388

@@ -400,4 +404,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
400404

401405
model.eval()
402406

407+
if device_map is not None:
408+
device_map_kwargs = {"device_map": device_map}
409+
dispatch_model(model, **device_map_kwargs)
410+
403411
return model

src/diffusers/loaders/single_file_utils.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,18 +1593,9 @@ def create_diffusers_clip_model_from_ldm(
15931593
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
15941594

15951595
if is_accelerate_available():
1596-
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
1596+
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
15971597
else:
1598-
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
1599-
1600-
if model._keys_to_ignore_on_load_unexpected is not None:
1601-
for pat in model._keys_to_ignore_on_load_unexpected:
1602-
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1603-
1604-
if len(unexpected_keys) > 0:
1605-
logger.warning(
1606-
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1607-
)
1598+
model.load_state_dict(diffusers_format_checkpoint, strict=False)
16081599

16091600
if torch_dtype is not None:
16101601
model.to(torch_dtype)
@@ -2061,16 +2052,7 @@ def create_diffusers_t5_model_from_checkpoint(
20612052
diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
20622053

20632054
if is_accelerate_available():
2064-
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
2065-
if model._keys_to_ignore_on_load_unexpected is not None:
2066-
for pat in model._keys_to_ignore_on_load_unexpected:
2067-
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
2068-
2069-
if len(unexpected_keys) > 0:
2070-
logger.warning(
2071-
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
2072-
)
2073-
2055+
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
20742056
else:
20752057
model.load_state_dict(diffusers_format_checkpoint)
20762058

Commit comments: 0