Skip to content

Commit af72ece

Browse files
committed
checking.
1 parent 9f4d997 commit af72ece

File tree

2 files changed

+167
-82
lines changed

2 files changed

+167
-82
lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
from array import array
2121
from collections import OrderedDict
22+
from concurrent.futures import ThreadPoolExecutor, as_completed
2223
from pathlib import Path
2324
from typing import Dict, List, Optional, Union
2425
from zipfile import is_zipfile
@@ -304,6 +305,130 @@ def load_model_dict_into_meta(
304305
return offload_index, state_dict_index
305306

306307

308+
def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
    """
    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the
    model's parameters.

    Returns `True` only when the first shared key exists in `state_dict` and its dtype matches the model's.
    """
    # Params on the meta device have no storage; assignment would leave them empty.
    if model_to_load.device.type == "meta":
        return False

    # Nothing under `start_prefix` to load, so assignment has nothing to do.
    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
        return False

    # Some models explicitly do not support param buffer assignment
    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
        logger.debug(
            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
        )
        return False

    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype.
    # FIX: build the model's state dict once — the original called `model_to_load.state_dict()` twice,
    # paying the O(num params) OrderedDict construction cost a second time for no benefit.
    model_state_dict = model_to_load.state_dict()
    first_key = next(iter(model_state_dict.keys()))
    if start_prefix + first_key in state_dict:
        return state_dict[start_prefix + first_key].dtype == model_state_dict[first_key].dtype

    return False
334+
335+
336+
def load_shard_file(args):
    """
    Load one checkpoint shard into `model`.

    `args` is a single tuple (so the function can be submitted as-is to an executor) carrying the
    model, its materialized state dict, the shard path/entry, and all pass-through loading options.

    Returns `(offload_index, state_dict_index, mismatched_keys, error_msgs)` where the two indices
    are the (possibly updated) offload bookkeeping dicts, `mismatched_keys` lists
    `(key, checkpoint_shape, model_shape)` tuples pruned for shape conflicts, and `error_msgs`
    collects failures from the non-low-memory loading path.
    """
    (
        model,
        model_state_dict,
        shard_file,
        device_map,
        dtype,
        hf_quantizer,
        keep_in_fp32_modules,
        dduf_entries,
        loaded_keys,
        unexpected_keys,
        offload_index,
        offload_folder,
        state_dict_index,
        state_dict_folder,
        ignore_mismatched_sizes,
        low_cpu_mem_usage,
    ) = args
    state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
    # Record and drop checkpoint entries whose shapes disagree with the model.
    mismatched_keys = _find_mismatched_keys(
        state_dict,
        model_state_dict,
        loaded_keys,
        ignore_mismatched_sizes,
    )
    error_msgs = []
    if low_cpu_mem_usage:
        offload_index, state_dict_index = load_model_dict_into_meta(
            model,
            state_dict,
            device_map=device_map,
            dtype=dtype,
            hf_quantizer=hf_quantizer,
            keep_in_fp32_modules=keep_in_fp32_modules,
            unexpected_keys=unexpected_keys,
            offload_folder=offload_folder,
            offload_index=offload_index,
            state_dict_index=state_dict_index,
            state_dict_folder=state_dict_folder,
        )
    else:
        # FIX: the original set `assign_to_params_buffers = None` and then immediately tested
        # `if assign_to_params_buffers is None:` — dead code left over from the previous
        # loop-carried implementation; the check could never be False here.
        assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
        error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
    return offload_index, state_dict_index, mismatched_keys, error_msgs
384+
385+
386+
def load_shard_files_with_threadpool(args_list):
    """
    Load all checkpoint shards in parallel with a thread pool.

    Each element of `args_list` is one `load_shard_file` argument tuple. The pool size is capped by
    the `HF_PARALLEL_LOADING_WORKERS` environment variable (default 8) and by the number of shards.

    Returns `(offload_index, state_dict_index, mismatched_keys, error_msgs)` aggregated over all
    shards. NOTE(review): the two indices are taken from whichever future completes last — this
    presumably relies on every shard threading the same index dicts through; confirm with callers.
    """
    # FIX: an empty `args_list` would make `min()` produce 0 workers (ThreadPoolExecutor raises
    # ValueError) and leave `offload_index`/`state_dict_index` unbound at the return below.
    if not args_list:
        return None, None, [], []

    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))

    # Do not spawn anymore workers than you need
    num_workers = min(len(args_list), num_workers)

    logger.info(f"Loading model weights in parallel with {num_workers} workers...")

    error_msgs = []
    mismatched_keys = []

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
            futures = [executor.submit(load_shard_file, arg) for arg in args_list]
            for future in as_completed(futures):
                offload_index, state_dict_index, _mismatched_keys, _error_msgs = future.result()
                error_msgs += _error_msgs
                mismatched_keys += _mismatched_keys
                pbar.update(1)

    return offload_index, state_dict_index, mismatched_keys, error_msgs
408+
409+
410+
def _find_mismatched_keys(
411+
state_dict,
412+
model_state_dict,
413+
loaded_keys,
414+
ignore_mismatched_sizes,
415+
):
416+
mismatched_keys = []
417+
if ignore_mismatched_sizes:
418+
for checkpoint_key in loaded_keys:
419+
model_key = checkpoint_key
420+
# If the checkpoint is sharded, we may not have the key here.
421+
if checkpoint_key not in state_dict:
422+
continue
423+
424+
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
425+
mismatched_keys.append(
426+
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
427+
)
428+
del state_dict[checkpoint_key]
429+
return mismatched_keys
430+
431+
307432
def _load_state_dict_into_model(
308433
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
309434
) -> List[str]:

src/diffusers/models/modeling_utils.py

Lines changed: 42 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@
6666
_determine_device_map,
6767
_fetch_index_file,
6868
_fetch_index_file_legacy,
69-
_load_state_dict_into_model,
70-
load_model_dict_into_meta,
69+
load_shard_file,
70+
load_shard_files_with_threadpool,
7171
load_state_dict,
7272
)
7373

@@ -200,34 +200,6 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
200200
return last_tuple[1].dtype
201201

202202

203-
def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
    """
    Decide whether `model_to_load` can take its weights via param/buffer assignment (the fast path
    used when loading onto empty weights).

    Returns False when the model sits on the meta device, when no state-dict key carries
    `start_prefix`, when the model opts out via `_supports_param_buffer_assignment`, or when the
    checkpoint dtype disagrees with the model's; True only when the first shared key matches.
    """
    if model_to_load.device.type == "meta":
        return False

    if not any(key.startswith(start_prefix) for key in state_dict):
        return False

    # Models may opt out of assignment explicitly.
    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
        logger.debug(
            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
        )
        return False

    # Assignment additionally requires matching dtypes between checkpoint and model.
    first_key = next(iter(model_to_load.state_dict().keys()))
    candidate = start_prefix + first_key
    if candidate in state_dict:
        return state_dict[candidate].dtype == model_to_load.state_dict()[first_key].dtype

    return False
229-
230-
231203
@contextmanager
232204
def no_init_weights():
233205
"""
@@ -926,6 +898,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
926898
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
927899
disable_mmap = kwargs.pop("disable_mmap", False)
928900

901+
# TODO: enable TRUE ENV VARs
902+
is_parallel_loading_enabled = bool(os.environ.get("HF_ENABLE_PARALLEL_LOADING", 1))
903+
904+
if is_parallel_loading_enabled and not low_cpu_mem_usage:
905+
raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
906+
929907
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
930908
torch_dtype = torch.float32
931909
logger.warning(
@@ -1261,6 +1239,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
12611239
hf_quantizer=hf_quantizer,
12621240
keep_in_fp32_modules=keep_in_fp32_modules,
12631241
dduf_entries=dduf_entries,
1242+
is_parallel_loading_enabled=is_parallel_loading_enabled,
12641243
)
12651244
loading_info = {
12661245
"missing_keys": missing_keys,
@@ -1456,6 +1435,7 @@ def _load_pretrained_model(
14561435
offload_state_dict: Optional[bool] = None,
14571436
offload_folder: Optional[Union[str, os.PathLike]] = None,
14581437
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
1438+
is_parallel_loading_enabled: Optional[bool] = False,
14591439
):
14601440
model_state_dict = model.state_dict()
14611441
expected_keys = list(model_state_dict.keys())
@@ -1470,8 +1450,6 @@ def _load_pretrained_model(
14701450
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
14711451

14721452
mismatched_keys = []
1473-
1474-
assign_to_params_buffers = None
14751453
error_msgs = []
14761454

14771455
# Deal with offload
@@ -1499,63 +1477,45 @@ def _load_pretrained_model(
14991477
# load_state_dict will manage the case where we pass a dict instead of a file
15001478
# if state dict is not None, it means that we don't need to read the files from resolved_model_file also
15011479
resolved_model_file = [state_dict]
1480+
is_file = not isinstance(state_dict, dict)
15021481

1503-
if len(resolved_model_file) > 1:
1504-
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
1505-
1506-
for shard_file in resolved_model_file:
1507-
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
1508-
1509-
def _find_mismatched_keys(
1510-
state_dict,
1511-
model_state_dict,
1512-
loaded_keys,
1513-
ignore_mismatched_sizes,
1514-
):
1515-
mismatched_keys = []
1516-
if ignore_mismatched_sizes:
1517-
for checkpoint_key in loaded_keys:
1518-
model_key = checkpoint_key
1519-
# If the checkpoint is sharded, we may not have the key here.
1520-
if checkpoint_key not in state_dict:
1521-
continue
1522-
1523-
if (
1524-
model_key in model_state_dict
1525-
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
1526-
):
1527-
mismatched_keys.append(
1528-
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1529-
)
1530-
del state_dict[checkpoint_key]
1531-
return mismatched_keys
1532-
1533-
mismatched_keys += _find_mismatched_keys(
1534-
state_dict,
1482+
# prepare the arguments.
1483+
args_list = [
1484+
(
1485+
model,
15351486
model_state_dict,
1487+
shard_file,
1488+
device_map,
1489+
dtype,
1490+
hf_quantizer,
1491+
keep_in_fp32_modules,
1492+
dduf_entries,
15361493
loaded_keys,
1494+
unexpected_keys,
1495+
offload_index,
1496+
offload_folder,
1497+
state_dict_index,
1498+
state_dict_folder,
15371499
ignore_mismatched_sizes,
1500+
low_cpu_mem_usage,
15381501
)
1502+
for shard_file in resolved_model_file
1503+
]
15391504

1540-
if low_cpu_mem_usage:
1541-
offload_index, state_dict_index = load_model_dict_into_meta(
1542-
model,
1543-
state_dict,
1544-
device_map=device_map,
1545-
dtype=dtype,
1546-
hf_quantizer=hf_quantizer,
1547-
keep_in_fp32_modules=keep_in_fp32_modules,
1548-
unexpected_keys=unexpected_keys,
1549-
offload_folder=offload_folder,
1550-
offload_index=offload_index,
1551-
state_dict_index=state_dict_index,
1552-
state_dict_folder=state_dict_folder,
1553-
)
1554-
else:
1555-
if assign_to_params_buffers is None:
1556-
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
1505+
if is_parallel_loading_enabled and is_file:
1506+
offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_files_with_threadpool(
1507+
args_list
1508+
)
1509+
error_msgs += _error_msgs
1510+
mismatched_keys += _mismatched_keys
1511+
else:
1512+
if len(args_list) > 1:
1513+
args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
15571514

1558-
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
1515+
for args in args_list:
1516+
offload_index, state_dict_index, _error_msgs = load_shard_file(args)
1517+
error_msgs += _error_msgs
1518+
mismatched_keys += _mismatched_keys
15591519

15601520
if offload_index is not None and len(offload_index) > 0:
15611521
save_offload_index(offload_index, offload_folder)

0 commit comments

Comments (0)