Skip to content

Commit 72b6259

Browse files
committed
fix dduf
1 parent 17c1be2 commit 72b6259

File tree

3 files changed

+5
-22
lines changed

3 files changed

+5
-22
lines changed

src/diffusers/loaders/unet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
IPAdapterPlusImageProjection,
3131
MultiIPAdapterImageProjection,
3232
)
33-
from ..models.modeling_utils import load_state_dict, load_model_dict_into_meta
33+
from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
3434
from ..utils import (
3535
USE_PEFT_BACKEND,
3636
_get_model_file,

src/diffusers/models/model_loading_utils.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -307,25 +307,6 @@ def load_model_dict_into_meta(
307307
return error_msgs, offload_index, state_dict_index
308308

309309

310-
def load_model_dict_into_meta(
311-
model,
312-
state_dict: OrderedDict,
313-
dtype: Optional[Union[str, torch.dtype]] = None,
314-
model_name_or_path: Optional[str] = None,
315-
hf_quantizer=None,
316-
keep_in_fp32_modules=None,
317-
device_map=None,
318-
unexpected_keys=None,
319-
is_safetensors=None,
320-
offload_folder=None,
321-
offload_index=None,
322-
state_dict_index=None,
323-
state_dict_folder=None,
324-
) -> List[str]:
325-
error_msgs = []
326-
return error_msgs, offload_index, state_dict_index
327-
328-
329310
def _load_state_dict_into_model(
330311
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
331312
) -> List[str]:

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@
6464
_fetch_index_file,
6565
_fetch_index_file_legacy,
6666
_load_state_dict_into_model,
67-
load_state_dict,
6867
load_model_dict_into_meta,
68+
load_state_dict,
6969
)
7070

7171

@@ -1033,6 +1033,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
10331033
dtype=torch_dtype,
10341034
hf_quantizer=hf_quantizer,
10351035
keep_in_fp32_modules=keep_in_fp32_modules,
1036+
dduf_entries=dduf_entries,
10361037
)
10371038
loading_info = {
10381039
"missing_keys": missing_keys,
@@ -1156,6 +1157,7 @@ def _load_pretrained_model(
11561157
device_map=None,
11571158
offload_state_dict=None,
11581159
offload_folder=None,
1160+
dduf_entries=None,
11591161
):
11601162
model_state_dict = model.state_dict()
11611163
expected_keys = list(model_state_dict.keys())
@@ -1209,7 +1211,7 @@ def _load_pretrained_model(
12091211
if len(resolved_archive_file) > 1:
12101212
resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
12111213
for shard_file in resolved_archive_file:
1212-
state_dict = load_state_dict(shard_file)
1214+
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
12131215
model._fix_state_dict_keys_on_load(state_dict)
12141216

12151217
def _find_mismatched_keys(

0 commit comments

Comments
 (0)