Skip to content

Commit cd13977

Browse files
sayakpaulDN6
andauthored
Apply suggestions from code review
Co-authored-by: Dhruv Nair <[email protected]>
1 parent 04bff1c commit cd13977

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,25 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi
339339
return False
340340

341341

342-
def load_shard_file(args):
342+
def load_shard_file(
343+
shard_file,
344+
model,
345+
model_state_dict,
346+
device_map=None,
347+
dtype=None,
348+
hf_quantizer=None,
349+
keep_in_fp32_modules=None,
350+
dduf_entries=None,
351+
loaded_keys=None,
352+
unexpected_keys=None,
353+
offload_index=None,
354+
offload_folder=None,
355+
state_dict_index=None,
356+
state_dict_folder=None,
357+
ignore_mismatched_sizes=False,
358+
low_cpu_mem_usage=False,
359+
):
360+
343361
(
344362
model,
345363
model_state_dict,
@@ -389,7 +407,24 @@ def load_shard_file(args):
389407
return offload_index, state_dict_index, mismatched_keys, error_msgs
390408

391409

392-
def load_shard_files_with_threadpool(args_list):
410+
def _load_shard_files_with_threadpool(
411+
shard_files,
412+
model,
413+
model_state_dict,
414+
device_map=None,
415+
dtype=None,
416+
hf_quantizer=None,
417+
keep_in_fp32_modules=None,
418+
dduf_entries=None,
419+
loaded_keys=None,
420+
unexpected_keys=None,
421+
offload_index=None,
422+
offload_folder=None,
423+
state_dict_index=None,
424+
state_dict_folder=None,
425+
ignore_mismatched_sizes=False,
426+
low_cpu_mem_usage=False,
427+
):
393428
num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
394429

395430
# Do not spawn anymore workers than you need

0 commit comments

Comments
 (0)