Skip to content

Commit c40f60c

Browse files
committed
update
1 parent 04d7dc3 commit c40f60c

File tree

2 files changed

+79
-88
lines changed

2 files changed

+79
-88
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
141141
return True
142142

143143

144-
def variant_compatible_siblings(filenames, variant=None, use_safetensors=True) -> Union[List[os.PathLike], str]:
144+
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
145145
weight_names = [
146146
WEIGHTS_NAME,
147147
SAFETENSORS_WEIGHTS_NAME,
@@ -177,17 +177,9 @@ def variant_compatible_siblings(filenames, variant=None, use_safetensors=True) -
177177
# `text_encoder/pytorch_model.bin.index.json`
178178
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
179179

180-
def filter_for_compatible_extensions(filenames, variant=None, use_safetensors=True):
181-
def is_safetensors(filename):
182-
return ".safetensors" in filename
183-
184-
def is_not_safetensors(filename):
185-
return ".safetensors" not in filename
186-
187-
if use_safetensors and is_safetensors_compatible(filenames):
188-
extension_filter = is_safetensors
189-
else:
190-
extension_filter = is_not_safetensors
180+
def filter_for_compatible_extensions(filenames, variant=None, ignore_patterns=None):
181+
def extension_filter(f):
182+
return not any(f.endswith(pattern) for pattern in ignore_patterns)
191183

192184
tensor_files = {f for f in filenames if extension_filter(f)}
193185
non_variant_indexes = {
@@ -222,7 +214,7 @@ def filter_for_weights_and_indexes(filenames, file_re, index_re):
222214
variant_filenames = set()
223215
for component, component_filenames in components.items():
224216
component_filenames = filter_for_compatible_extensions(
225-
component_filenames, variant=variant, use_safetensors=use_safetensors
217+
component_filenames, variant=variant, ignore_patterns=ignore_patterns
226218
)
227219

228220
component_variants = set()
@@ -239,6 +231,18 @@ def filter_for_weights_and_indexes(filenames, file_re, index_re):
239231
)
240232
usable_filenames.update(component_non_variants)
241233

234+
if len(variant_filenames) == 0 and variant is not None:
235+
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
236+
raise ValueError(error_message)
237+
238+
if len(variant_filenames) > 0 and usable_filenames != variant_filenames:
239+
logger.warning(
240+
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
241+
f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n"
242+
f"[{', '.join(usable_filenames - variant_filenames)}\nIf this behavior is not "
243+
f"expected, please check your folder structure."
244+
)
245+
242246
return usable_filenames, variant_filenames
243247

244248

@@ -933,18 +937,13 @@ def _get_custom_components_and_folders(
933937
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
934938
)
935939

936-
if len(variant_filenames) == 0 and variant is not None:
937-
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
938-
raise ValueError(error_message)
939-
940940
return custom_components, folder_names
941941

942942

943943
def _get_ignore_patterns(
944944
passed_components,
945945
model_folder_names: List[str],
946946
model_filenames: List[str],
947-
variant_filenames: List[str],
948947
use_safetensors: bool,
949948
from_flax: bool,
950949
allow_pickle: bool,
@@ -975,33 +974,13 @@ def _get_ignore_patterns(
975974
if not use_onnx:
976975
ignore_patterns += ["*.onnx", "*.pb"]
977976

978-
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
979-
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
980-
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
981-
logger.warning(
982-
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
983-
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
984-
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
985-
f"expected, please check your folder structure."
986-
)
987-
988977
else:
989978
ignore_patterns = ["*.safetensors", "*.msgpack"]
990979

991980
use_onnx = use_onnx if use_onnx is not None else is_onnx
992981
if not use_onnx:
993982
ignore_patterns += ["*.onnx", "*.pb"]
994983

995-
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
996-
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
997-
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
998-
logger.warning(
999-
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
1000-
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
1001-
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
1002-
f"your folder structure."
1003-
)
1004-
1005984
return ignore_patterns
1006985

1007986

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 62 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,10 +1343,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13431343
revision=revision,
13441344
)
13451345

1346-
allow_pickle = False
1347-
if use_safetensors is None:
1348-
use_safetensors = True
1349-
allow_pickle = True
1346+
allow_pickle = True if (use_safetensors is None or use_safetensors is False) else False
1347+
use_safetensors = use_safetensors if use_safetensors is not None else True
13501348

13511349
allow_patterns = None
13521350
ignore_patterns = None
@@ -1361,6 +1359,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13611359
model_info_call_error = e # save error to reraise it if model is not cached locally
13621360

13631361
if not local_files_only:
1362+
config_file = hf_hub_download(
1363+
pretrained_model_name,
1364+
cls.config_name,
1365+
cache_dir=cache_dir,
1366+
revision=revision,
1367+
proxies=proxies,
1368+
force_download=force_download,
1369+
token=token,
1370+
)
1371+
config_dict = cls._dict_from_json_file(config_file)
1372+
ignore_filenames = config_dict.pop("_ignore_files", [])
1373+
13641374
filenames = {sibling.rfilename for sibling in info.siblings}
13651375
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
13661376
warn_msg = (
@@ -1375,61 +1385,20 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13751385
)
13761386
logger.warning(warn_msg)
13771387

1378-
model_filenames, variant_filenames = variant_compatible_siblings(
1379-
filenames, variant=variant, use_safetensors=use_safetensors
1380-
)
1381-
1382-
config_file = hf_hub_download(
1383-
pretrained_model_name,
1384-
cls.config_name,
1385-
cache_dir=cache_dir,
1386-
revision=revision,
1387-
proxies=proxies,
1388-
force_download=force_download,
1389-
token=token,
1390-
)
1391-
1392-
config_dict = cls._dict_from_json_file(config_file)
1393-
ignore_filenames = config_dict.pop("_ignore_files", [])
1394-
1395-
# remove ignored filenames
1396-
model_filenames = set(model_filenames) - set(ignore_filenames)
1397-
variant_filenames = set(variant_filenames) - set(ignore_filenames)
1398-
1388+
filenames = set(filenames) - set(ignore_filenames)
13991389
if revision in DEPRECATED_REVISION_ARGS and version.parse(
14001390
version.parse(__version__).base_version
14011391
) >= version.parse("0.22.0"):
1402-
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
1392+
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, filenames)
14031393

14041394
custom_components, folder_names = _get_custom_components_and_folders(
1405-
pretrained_model_name, config_dict, filenames, variant_filenames, variant
1395+
pretrained_model_name, config_dict, filenames, variant
14061396
)
1407-
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
1408-
14091397
custom_class_name = None
14101398
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
14111399
custom_pipeline = config_dict["_class_name"][0]
14121400
custom_class_name = config_dict["_class_name"][1]
14131401

1414-
# all filenames compatible with variant will be added
1415-
allow_patterns = list(model_filenames)
1416-
1417-
# allow all patterns from non-model folders
1418-
# this enables downloading schedulers, tokenizers, ...
1419-
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
1420-
# add custom component files
1421-
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
1422-
# add custom pipeline file
1423-
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
1424-
# also allow downloading config.json files with the model
1425-
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
1426-
allow_patterns += [
1427-
SCHEDULER_CONFIG_NAME,
1428-
CONFIG_NAME,
1429-
cls.config_name,
1430-
CUSTOM_PIPELINE_FILE_NAME,
1431-
]
1432-
14331402
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
14341403
load_components_from_hub = len(custom_components) > 0
14351404

@@ -1446,6 +1415,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14461415
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
14471416
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
14481417
)
1418+
model_folder_names = {os.path.split(f)[0] for f in filenames if os.path.split(f)[0] in folder_names}
14491419

14501420
# retrieve passed components that should not be downloaded
14511421
pipeline_class = _get_pipeline_class(
@@ -1466,8 +1436,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14661436
ignore_patterns = _get_ignore_patterns(
14671437
passed_components,
14681438
model_folder_names,
1469-
model_filenames,
1470-
variant_filenames,
1439+
filenames,
14711440
use_safetensors,
14721441
from_flax,
14731442
allow_pickle,
@@ -1476,6 +1445,49 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14761445
variant,
14771446
)
14781447

1448+
model_filenames, variant_filenames = variant_compatible_siblings(
1449+
filenames, variant=variant, ignore_patterns=ignore_patterns
1450+
)
1451+
1452+
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
1453+
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
1454+
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
1455+
logger.warning(
1456+
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
1457+
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
1458+
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
1459+
f"expected, please check your folder structure."
1460+
)
1461+
1462+
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
1463+
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
1464+
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
1465+
logger.warning(
1466+
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
1467+
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
1468+
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
1469+
f"your folder structure."
1470+
)
1471+
1472+
# all filenames compatible with variant will be added
1473+
allow_patterns = list(model_filenames)
1474+
1475+
# allow all patterns from non-model folders
1476+
# this enables downloading schedulers, tokenizers, ...
1477+
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
1478+
# add custom component files
1479+
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
1480+
# add custom pipeline file
1481+
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
1482+
# also allow downloading config.json files with the model
1483+
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
1484+
allow_patterns += [
1485+
SCHEDULER_CONFIG_NAME,
1486+
CONFIG_NAME,
1487+
cls.config_name,
1488+
CUSTOM_PIPELINE_FILE_NAME,
1489+
]
1490+
14791491
# Don't download any objects that are passed
14801492
allow_patterns = [
14811493
p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)

0 commit comments

Comments
 (0)