Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions src/diffusers/loaders/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
hosted on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
component configs in Diffusers format.
checkpoint (`dict`, *optional*):
The loaded state dictionary of the model.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
Expand Down Expand Up @@ -355,6 +357,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
checkpoint = kwargs.pop("checkpoint", None)

is_legacy_loading = False

Expand All @@ -375,15 +378,16 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):

pipeline_class = _get_pipeline_class(cls, config=None)

checkpoint = load_single_file_checkpoint(
pretrained_model_link_or_path,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
if checkpoint is None:
checkpoint = load_single_file_checkpoint(
pretrained_model_link_or_path,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)

if config is None:
config = fetch_diffusers_config(checkpoint)
Expand Down
Loading