Commit dca6388

Commit message: up
1 parent 8968e2f commit dca6388

File tree

2 files changed: +35 -5 lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 2 additions & 5 deletions

@@ -359,7 +359,6 @@ def _load_shard_file(
     ignore_mismatched_sizes=False,
     low_cpu_mem_usage=False,
 ):
-    assign_to_params_buffers = None
     state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
     mismatched_keys = _find_mismatched_keys(
         state_dict,
@@ -383,8 +382,7 @@ def _load_shard_file(
             state_dict_folder=state_dict_folder,
         )
     else:
-        if assign_to_params_buffers is None:
-            assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+        assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
 
         error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
     return offload_index, state_dict_index, mismatched_keys, error_msgs
@@ -408,9 +406,8 @@ def _load_shard_files_with_threadpool(
     ignore_mismatched_sizes=False,
     low_cpu_mem_usage=False,
 ):
-    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", str(DEFAULT_HF_PARALLEL_LOADING_WORKERS)))
-
     # Do not spawn anymore workers than you need
+    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", str(DEFAULT_HF_PARALLEL_LOADING_WORKERS)))
     num_workers = min(len(shard_files), num_workers)
 
     logger.info(f"Loading model weights in parallel with {num_workers} workers...")
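
For reference, the relocated worker-count lookup in _load_shard_files_with_threadpool boils down to: read an optional HF_PARALLEL_LOADING_WORKERS override, fall back to a library default, then cap the pool at the number of shard files. Below is a minimal sketch of that behavior; only the HF_PARALLEL_LOADING_WORKERS name comes from the diff, while the default value and the shard list are placeholders standing in for DEFAULT_HF_PARALLEL_LOADING_WORKERS and real checkpoint files.

import os

# Sketch of the worker-count resolution as ordered after this commit.
DEFAULT_WORKERS = 8  # placeholder for DEFAULT_HF_PARALLEL_LOADING_WORKERS


def resolve_num_workers(shard_files):
    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", str(DEFAULT_WORKERS)))
    # Do not spawn more workers than there are shard files to load.
    return min(len(shard_files), num_workers)


print(resolve_num_workers(["model-00001.safetensors", "model-00002.safetensors"]))  # at most 2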

tests/models/test_modeling_common.py

Lines changed: 33 additions & 0 deletions

@@ -1428,6 +1428,39 @@ def test_sharded_checkpoints_with_variant(self):
 
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
+    @require_torch_accelerator
+    def test_sharded_checkpoints_with_parallel_loading(self):
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_persistent_sizes(model)[""]
+        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+
+            # Now check if the right number of shards exists. First, let's get the number of shards.
+            # Since this number can be dependent on the model being tested, it's important that we calculate it
+            # instead of hardcoding it.
+            expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
+            actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
+            self.assertTrue(actual_num_shards == expected_num_shards)
+
+            # Load with parallel loading
+            os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"
+            new_model = self.model_class.from_pretrained(tmp_dir).eval()
+            new_model = new_model.to(torch_device)
+
+            torch.manual_seed(0)
+            if "generator" in inputs_dict:
+                _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            new_output = new_model(**inputs_dict)
+            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
     @require_torch_accelerator
     def test_sharded_checkpoints_device_map(self):
         if self.model_class._no_split_modules is None:
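
Outside the test suite, the behavior this new test exercises is driven by environment variables. A hedged usage sketch follows, assuming a sharded safetensors checkpoint: the HF_ENABLE_PARALLEL_LOADING and HF_PARALLEL_LOADING_WORKERS names come from this commit's diff, while the model class and repository id are placeholders, not part of the change.

import os

import torch
from diffusers import UNet2DConditionModel  # placeholder: any diffusers model class with a sharded checkpoint

# Opt in to parallel shard loading; this mirrors the switch the new test flips.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"  # optional cap on the loading thread pool

model = UNet2DConditionModel.from_pretrained(
    "some-org/some-sharded-checkpoint",  # placeholder repo id
    torch_dtype=torch.float16,
)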
