Skip to content

Commit ac420af

Browse files
committed
fix logic
1 parent 290b88d commit ac420af

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ def load_state_dict(
136136
checkpoint_file: Union[str, os.PathLike],
137137
variant: Optional[str] = None,
138138
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
139-
disable_mmap: bool = False,
140139
):
141140
"""
142141
Reads a checkpoint file, returning properly formatted errors if they arise.
@@ -152,8 +151,6 @@ def load_state_dict(
152151
# tensors are loaded on cpu
153152
with dduf_entries[checkpoint_file].as_mmap() as mm:
154153
return safetensors.torch.load(mm)
155-
if disable_mmap:
156-
return safetensors.torch.load(open(checkpoint_file, "rb").read())
157154
else:
158155
return safetensors.torch.load_file(checkpoint_file, device="cpu")
159156
elif file_extension == GGUF_FILE_EXTENSION:
@@ -345,8 +342,14 @@ def _merge_sharded_checkpoints(
345342
# Load tensors from each unique file
346343
for file_name in files_to_load:
347344
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
348-
if not os.path.exists(part_file_path) or (dduf_entries and part_file_path not in dduf_entries):
349-
raise FileNotFoundError(f"Part file {file_name} not found.")
345+
if dduf_entries:
346+
# If dduf_entries is provided, check if part_file_path is in it
347+
if part_file_path not in dduf_entries:
348+
raise FileNotFoundError(f"Part file {file_name} not found.")
349+
else:
350+
# If dduf_entries is not provided, check if the file exists on disk
351+
if not os.path.exists(part_file_path):
352+
raise FileNotFoundError(f"Part file {file_name} not found.")
350353

351354
if is_safetensors:
352355
if dduf_entries:

src/diffusers/utils/hub_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,14 @@ def _get_checkpoint_shard_files(
449449
For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
450450
index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
451451
"""
452-
if not os.path.isfile(index_filename) and (dduf_entries and index_filename not in dduf_entries):
453-
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
452+
if dduf_entries:
453+
# If dduf_entries is provided, check if part_file_path is in it
454+
if index_filename not in dduf_entries:
455+
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
456+
else:
457+
# If dduf_entries is not provided, check if the file exists on disk
458+
if not os.path.isfile(index_filename):
459+
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
454460

455461
if dduf_entries:
456462
index = json.loads(dduf_entries[index_filename].read_text())

0 commit comments

Comments
 (0)