
Commit 8968e2f
Commit message: up
1 parent cd13977

File tree: 4 files changed, 56 additions and 55 deletions

src/diffusers/models/model_loading_utils.py

Lines changed: 26 additions & 24 deletions
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import functools
 import importlib
 import inspect
 import math
@@ -32,6 +33,7 @@

 from ..quantizers import DiffusersQuantizer
 from ..utils import (
+    DEFAULT_HF_PARALLEL_LOADING_WORKERS,
     GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -339,7 +341,7 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi
     return False


-def load_shard_file(
+def _load_shard_file(
     shard_file,
     model,
     model_state_dict,
@@ -357,25 +359,6 @@ def load_shard_file(
     ignore_mismatched_sizes=False,
     low_cpu_mem_usage=False,
 ):
-
-    (
-        model,
-        model_state_dict,
-        shard_file,
-        device_map,
-        dtype,
-        hf_quantizer,
-        keep_in_fp32_modules,
-        dduf_entries,
-        loaded_keys,
-        unexpected_keys,
-        offload_index,
-        offload_folder,
-        state_dict_index,
-        state_dict_folder,
-        ignore_mismatched_sizes,
-        low_cpu_mem_usage,
-    ) = args
     assign_to_params_buffers = None
     state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
     mismatched_keys = _find_mismatched_keys(
@@ -425,19 +408,38 @@ def _load_shard_files_with_threadpool(
     ignore_mismatched_sizes=False,
     low_cpu_mem_usage=False,
 ):
-    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
+    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", str(DEFAULT_HF_PARALLEL_LOADING_WORKERS)))

     # Do not spawn anymore workers than you need
-    num_workers = min(len(args_list), num_workers)
+    num_workers = min(len(shard_files), num_workers)

     logger.info(f"Loading model weights in parallel with {num_workers} workers...")

     error_msgs = []
     mismatched_keys = []

+    load_one = functools.partial(
+        _load_shard_file,
+        model=model,
+        model_state_dict=model_state_dict,
+        device_map=device_map,
+        dtype=dtype,
+        hf_quantizer=hf_quantizer,
+        keep_in_fp32_modules=keep_in_fp32_modules,
+        dduf_entries=dduf_entries,
+        loaded_keys=loaded_keys,
+        unexpected_keys=unexpected_keys,
+        offload_index=offload_index,
+        offload_folder=offload_folder,
+        state_dict_index=state_dict_index,
+        state_dict_folder=state_dict_folder,
+        ignore_mismatched_sizes=ignore_mismatched_sizes,
+        low_cpu_mem_usage=low_cpu_mem_usage,
+    )
+
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
-        with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
-            futures = [executor.submit(load_shard_file, arg) for arg in args_list]
+        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+            futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
             for future in as_completed(futures):
                 result = future.result()
                 offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
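
Editor's note: a minimal, self-contained sketch of the pattern this hunk adopts (hypothetical names, not the diffusers API): bind the keyword arguments shared by every shard once with functools.partial, then submit one task per shard file to a ThreadPoolExecutor and consume results as they complete.

import functools
from concurrent.futures import ThreadPoolExecutor, as_completed


def load_shard(shard_file, *, device, dtype):
    # Stand-in for the real per-shard loading work.
    return f"loaded {shard_file} on {device} as {dtype}"


shard_files = ["model-00001.safetensors", "model-00002.safetensors"]
load_one = functools.partial(load_shard, device="cpu", dtype="float32")

# Never spawn more workers than there are shards.
num_workers = min(len(shard_files), 8)

with ThreadPoolExecutor(max_workers=num_workers) as executor:
    futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
    for future in as_completed(futures):
        print(future.result())

Binding the shared arguments up front keeps the executor.submit call down to the one value that varies per task, which is what lets the old tuple-packing/unpacking code above be deleted.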

src/diffusers/models/modeling_utils.py

Lines changed: 28 additions & 31 deletions
@@ -15,6 +15,7 @@
 # limitations under the License.

 import copy
+import functools
 import inspect
 import itertools
 import json
@@ -70,8 +71,8 @@
7071
_expand_device_map,
7172
_fetch_index_file,
7273
_fetch_index_file_legacy,
73-
load_shard_file,
74-
load_shard_files_with_threadpool,
74+
_load_shard_file,
75+
_load_shard_files_with_threadpool,
7576
load_state_dict,
7677
)
7778

@@ -1547,41 +1548,37 @@ def _load_pretrained_model(
             # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
             resolved_model_file = [state_dict]

-        # prepare the arguments.
-        args_list = [
-            (
-                model,
-                model_state_dict,
-                shard_file,
-                device_map,
-                dtype,
-                hf_quantizer,
-                keep_in_fp32_modules,
-                dduf_entries,
-                loaded_keys,
-                unexpected_keys,
-                offload_index,
-                offload_folder,
-                state_dict_index,
-                state_dict_folder,
-                ignore_mismatched_sizes,
-                low_cpu_mem_usage,
-            )
-            for shard_file in resolved_model_file
-        ]
+        # Prepare the loading function sharing the attributes shared between them.
+        load_fn = functools.partial(
+            _load_shard_files_with_threadpool if is_parallel_loading_enabled else _load_shard_file,
+            model=model,
+            model_state_dict=model_state_dict,
+            device_map=device_map,
+            dtype=dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            dduf_entries=dduf_entries,
+            loaded_keys=loaded_keys,
+            unexpected_keys=unexpected_keys,
+            offload_index=offload_index,
+            offload_folder=offload_folder,
+            state_dict_index=state_dict_index,
+            state_dict_folder=state_dict_folder,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

         if is_parallel_loading_enabled:
-            offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_files_with_threadpool(
-                args_list
-            )
+            offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(resolved_model_file)
             error_msgs += _error_msgs
             mismatched_keys += _mismatched_keys
         else:
-            if len(args_list) > 1:
-                args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
+            shard_files = resolved_model_file
+            if len(resolved_model_file) > 1:
+                shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")

-            for args in args_list:
-                offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_file(args)
+            for shard_file in shard_files:
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
                 error_msgs += _error_msgs
                 mismatched_keys += _mismatched_keys
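
Editor's note: a small sketch of the dispatch idea used above (hypothetical function names, not the diffusers API): the callable is chosen first, the shared keyword arguments are bound once, and the resulting load_fn is then called either once with the whole list (parallel path) or once per shard (sequential path). This only works because both callables accept the same keyword arguments.

import functools


def load_single(shard_file, *, tag):
    return f"{tag}:{shard_file}"


def load_many(shard_files, *, tag):
    return [load_single(shard_file, tag=tag) for shard_file in shard_files]


def load_all(shard_files, parallel):
    # Pick the callable, then bind the kwargs shared by both paths.
    load_fn = functools.partial(load_many if parallel else load_single, tag="ckpt")
    if parallel:
        return load_fn(shard_files)  # one call over the whole list
    return [load_fn(shard_file) for shard_file in shard_files]  # one call per shard


print(load_all(["a.safetensors", "b.safetensors"], parallel=True))
print(load_all(["a.safetensors", "b.safetensors"], parallel=False))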

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 from .. import __version__
 from .constants import (
     CONFIG_NAME,
+    DEFAULT_HF_PARALLEL_LOADING_WORKERS,
     DEPRECATED_REVISION_ARGS,
     DIFFUSERS_DYNAMIC_MODULE_NAME,
     FLAX_WEIGHTS_NAME,
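
Editor's note: with this re-export in place, the constant should be reachable from the public utils namespace; an assumed usage sketch:

from diffusers.utils import DEFAULT_HF_PARALLEL_LOADING_WORKERS

print(DEFAULT_HF_PARALLEL_LOADING_WORKERS)  # 8, per constants.py below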

src/diffusers/utils/constants.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@
 DIFFUSERS_REQUEST_TIMEOUT = 60
 DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
 DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
+DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8

 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
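
Editor's note: how the new constant interacts with the HF_PARALLEL_LOADING_WORKERS environment variable, following the line added to _load_shard_files_with_threadpool above; the constant supplies the default and the environment variable, if set, overrides it at runtime.

import os

DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8

# Same override pattern as in model_loading_utils.py above.
num_workers = int(
    os.environ.get("HF_PARALLEL_LOADING_WORKERS", str(DEFAULT_HF_PARALLEL_LOADING_WORKERS))
)
print(num_workers)  # 8 unless HF_PARALLEL_LOADING_WORKERS is set in the environment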
