Commit f3a4ddc

sharded checkpoint compat
1 parent da48dcb commit f3a4ddc

3 files changed: 31 additions, 14 deletions
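
For orientation, a usage sketch of what this compatibility fix enables: loading a DDUF archive whose model weights are stored as a sharded checkpoint. This assumes the `dduf_file` argument of `DiffusionPipeline.from_pretrained`; the repo id and file name below are illustrative, not part of this commit.

```python
import torch
from diffusers import DiffusionPipeline

# Illustrative Hub repo and archive name; any DDUF archive containing a
# sharded (multi-file) safetensors checkpoint exercises the code paths
# patched in this commit.
pipe = DiffusionPipeline.from_pretrained(
    "DDUF/FLUX.1-dev-DDUF",
    dduf_file="FLUX.1-dev.dduf",
    torch_dtype=torch.bfloat16,
)
```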

src/diffusers/models/model_loading_utils.py

Lines changed: 15 additions & 7 deletions
@@ -315,7 +315,8 @@ def _fetch_index_file(
             commit_hash=commit_hash,
             dduf_entries=dduf_entries,
         )
-        index_file = Path(index_file)
+        if not dduf_entries:
+            index_file = Path(index_file)
     except (EntryNotFoundError, EnvironmentError):
         index_file = None

@@ -324,7 +325,9 @@ def _fetch_index_file(
 
 # Adapted from
 # https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
-def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
+def _merge_sharded_checkpoints(
+    sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None
+):
     weight_map = sharded_metadata.get("weight_map", None)
     if weight_map is None:
         raise KeyError("'weight_map' key not found in the shard index file.")
@@ -337,14 +340,19 @@ def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
     # Load tensors from each unique file
     for file_name in files_to_load:
         part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
-        if not os.path.exists(part_file_path):
+        if not os.path.exists(part_file_path) and (dduf_entries and part_file_path not in dduf_entries):
             raise FileNotFoundError(f"Part file {file_name} not found.")
 
         if is_safetensors:
-            with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
-                for tensor_key in f.keys():
-                    if tensor_key in weight_map:
-                        merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
+            if dduf_entries:
+                with dduf_entries[part_file_path].as_mmap() as mm:
+                    tensors = safetensors.torch.load(mm)
+                    merged_state_dict.update(tensors)
+            else:
+                with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
+                    for tensor_key in f.keys():
+                        if tensor_key in weight_map:
+                            merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
         else:
             merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
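
A self-contained sketch of the merge logic added above, assuming `DDUFEntry` from `huggingface_hub` and safetensors-format shards; the `merge_shards` name and simplified signature are ours, not the library's. Inside a DDUF archive the shard is deserialized wholesale from a memory-mapped zip entry, while the on-disk path still streams tensors lazily through `safe_open`.

```python
import os
from typing import Dict, Optional

import safetensors
import safetensors.torch
from huggingface_hub import DDUFEntry


def merge_shards(
    folder: str,
    weight_map: Dict[str, str],  # tensor name -> shard file name, from the index
    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
) -> dict:
    merged_state_dict = {}
    for file_name in sorted(set(weight_map.values())):
        part_file_path = os.path.join(folder, file_name)
        if dduf_entries:
            # Shard lives inside the DDUF zip: mmap the entry and load the
            # whole shard from bytes in one call.
            with dduf_entries[part_file_path].as_mmap() as mm:
                merged_state_dict.update(safetensors.torch.load(mm))
        else:
            # Shard on disk: open lazily and copy only the mapped tensors.
            with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key in weight_map:
                        merged_state_dict[key] = f.get_tensor(key)
    return merged_state_dict
```

Note that the DDUF branch in the diff looks up `dduf_entries` by the joined path, so the entry dict is expected to be keyed by the same archive-relative paths.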

src/diffusers/models/modeling_utils.py

Lines changed: 7 additions & 4 deletions
@@ -782,7 +782,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         # this becomes applicable when the variant is not None.
         if variant is not None and (index_file is None or not os.path.exists(index_file)):
             index_file = _fetch_index_file_legacy(**index_file_kwargs)
-        if index_file is not None and index_file.is_file():
+        if index_file is not None and (dduf_entries or index_file.is_file()):
             is_sharded = True
 
         if is_sharded and from_flax:
@@ -812,7 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             model = load_flax_checkpoint_in_pytorch_model(model, model_file)
         else:
             # in the case it is sharded, we have already the index
-            if is_sharded and not dduf_entries:
+            if is_sharded:
                 sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
                     pretrained_model_name_or_path,
                     index_file,
@@ -823,9 +823,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     user_agent=user_agent,
                     revision=revision,
                     subfolder=subfolder or "",
+                    dduf_entries=dduf_entries,
                 )
-                if hf_quantizer is not None and is_bnb_quantization_method:
-                    model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
+                if (hf_quantizer is not None and is_bnb_quantization_method) or dduf_entries:
+                    model_file = _merge_sharded_checkpoints(
+                        sharded_ckpt_cached_folder, sharded_metadata, dduf_entries=dduf_entries
+                    )
                     logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
                     is_sharded = False
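
The net effect in `from_pretrained`: a sharded DDUF checkpoint now always takes the eager-merge path that bitsandbytes quantization already used, since shards inside the archive cannot be handed to the per-file loader. A condensed sketch of that gate (a hypothetical helper with simplified names; the real logic is inline above):

```python
def _should_merge_shards(is_sharded, hf_quantizer, is_bnb_quantization_method, dduf_entries):
    # Merge shards eagerly into one state dict when a quantizer needs the
    # full checkpoint (bnb) or when the shards live inside a DDUF archive.
    return is_sharded and (
        (hf_quantizer is not None and is_bnb_quantization_method) or bool(dduf_entries)
    )
```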

src/diffusers/utils/hub_utils.py

Lines changed: 9 additions & 3 deletions
@@ -437,6 +437,7 @@ def _get_checkpoint_shard_files(
     user_agent=None,
     revision=None,
     subfolder="",
+    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
 ):
     """
     For a given model:
@@ -448,11 +449,14 @@ def _get_checkpoint_shard_files(
     For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
     index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
     """
-    if not os.path.isfile(index_filename):
+    if not os.path.isfile(index_filename) and (dduf_entries and index_filename not in dduf_entries):
         raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
 
-    with open(index_filename, "r") as f:
-        index = json.loads(f.read())
+    if dduf_entries:
+        index = json.loads(dduf_entries[index_filename].read_text())
+    else:
+        with open(index_filename, "r") as f:
+            index = json.loads(f.read())
 
     original_shard_filenames = sorted(set(index["weight_map"].values()))
     sharded_metadata = index["metadata"]
@@ -466,6 +470,8 @@
             pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
         )
         return shards_path, sharded_metadata
+    elif dduf_entries:
+        return shards_path, sharded_metadata
 
     # At this stage pretrained_model_name_or_path is a model identifier on the Hub
     allow_patterns = original_shard_filenames
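
A minimal sketch of the index-resolution change above: when the checkpoint lives inside a DDUF archive, the JSON shard index is read from the matching `DDUFEntry` (via `read_text()`) instead of the local filesystem, and the function then returns early just as it does for a local folder, since there is nothing to download from the Hub. The helper below is illustrative, not the library's API.

```python
import json
from typing import Dict, Optional

from huggingface_hub import DDUFEntry


def read_shard_index(
    index_filename: str,
    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
) -> dict:
    if dduf_entries:
        # Entry paths are archive-relative; read_text() decodes the JSON.
        return json.loads(dduf_entries[index_filename].read_text())
    with open(index_filename, "r") as f:
        return json.loads(f.read())
```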
