Skip to content

Commit d1c4a61

Browse files
committed
rename resolved_archive_file to resolved_model_file
1 parent abd3a91 commit d1c4a61

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,7 +1029,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
10291029
keep_in_fp32_modules = []
10301030

10311031
is_sharded = False
1032-
resolved_archive_file = None
1032+
resolved_model_file = None
10331033

10341034
# Determine if we're loading from a directory of sharded checkpoints.
10351035
sharded_metadata = None
@@ -1064,7 +1064,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
10641064

10651065
# load model
10661066
if from_flax:
1067-
resolved_archive_file = _get_model_file(
1067+
resolved_model_file = _get_model_file(
10681068
pretrained_model_name_or_path,
10691069
weights_name=FLAX_WEIGHTS_NAME,
10701070
cache_dir=cache_dir,
@@ -1082,11 +1082,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
10821082
# Convert the weights
10831083
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
10841084

1085-
model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
1085+
model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
10861086
else:
10871087
# in the case it is sharded, we have already the index
10881088
if is_sharded:
1089-
resolved_archive_file, sharded_metadata = _get_checkpoint_shard_files(
1089+
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
10901090
pretrained_model_name_or_path,
10911091
index_file,
10921092
cache_dir=cache_dir,
@@ -1100,7 +1100,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
11001100
)
11011101
elif use_safetensors:
11021102
try:
1103-
resolved_archive_file = _get_model_file(
1103+
resolved_model_file = _get_model_file(
11041104
pretrained_model_name_or_path,
11051105
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
11061106
cache_dir=cache_dir,
@@ -1123,8 +1123,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
11231123
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
11241124
)
11251125

1126-
if resolved_archive_file is None and not is_sharded:
1127-
resolved_archive_file = _get_model_file(
1126+
if resolved_model_file is None and not is_sharded:
1127+
resolved_model_file = _get_model_file(
11281128
pretrained_model_name_or_path,
11291129
weights_name=_add_variant(WEIGHTS_NAME, variant),
11301130
cache_dir=cache_dir,
@@ -1139,8 +1139,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
11391139
dduf_entries=dduf_entries,
11401140
)
11411141

1142-
if not isinstance(resolved_archive_file, list):
1143-
resolved_archive_file = [resolved_archive_file]
1142+
if not isinstance(resolved_model_file, list):
1143+
resolved_model_file = [resolved_model_file]
11441144

11451145
# set dtype to instantiate the model under:
11461146
# 1. If torch_dtype is not None, we use that dtype
@@ -1168,7 +1168,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
11681168
if not is_sharded:
11691169
# Time to load the checkpoint
11701170
state_dict = load_state_dict(
1171-
resolved_archive_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries
1171+
resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries
11721172
)
11731173
# We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
11741174
model._fix_state_dict_keys_on_load(state_dict)
@@ -1200,7 +1200,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
12001200
) = cls._load_pretrained_model(
12011201
model,
12021202
state_dict,
1203-
resolved_archive_file,
1203+
resolved_model_file,
12041204
pretrained_model_name_or_path,
12051205
loaded_keys,
12061206
ignore_mismatched_sizes=ignore_mismatched_sizes,
@@ -1361,7 +1361,7 @@ def _load_pretrained_model(
13611361
cls,
13621362
model,
13631363
state_dict: OrderedDict,
1364-
resolved_archive_file: List[str],
1364+
resolved_model_file: List[str],
13651365
pretrained_model_name_or_path: Union[str, os.PathLike],
13661366
loaded_keys: List[str],
13671367
ignore_mismatched_sizes: bool = False,
@@ -1415,13 +1415,13 @@ def _load_pretrained_model(
14151415

14161416
if state_dict is not None:
14171417
# load_state_dict will manage the case where we pass a dict instead of a file
1418-
# if state dict is not None, it means that we don't need to read the files from resolved_archive_file also
1419-
resolved_archive_file = [state_dict]
1418+
# if state dict is not None, it means that we don't need to read the files from resolved_model_file also
1419+
resolved_model_file = [state_dict]
14201420

1421-
if len(resolved_archive_file) > 1:
1422-
resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
1421+
if len(resolved_model_file) > 1:
1422+
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
14231423

1424-
for shard_file in resolved_archive_file:
1424+
for shard_file in resolved_model_file:
14251425
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
14261426

14271427
def _find_mismatched_keys(

0 commit comments

Comments
 (0)