Skip to content

Commit 3388009

Browse files
committed
AutoPipeline enhancements
1 parent 7ac6e28 commit 3388009

File tree

2 files changed

+1363
-17
lines changed

2 files changed

+1363
-17
lines changed

src/diffusers/loaders/single_file.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,8 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
308308
hosted on the Hub.
309309
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
310310
component configs in Diffusers format.
311+
checkpoint (`dict`, *optional*):
312+
The loaded state dictionary of the model.
311313
kwargs (remaining dictionary of keyword arguments, *optional*):
312314
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
313315
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -355,6 +357,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
355357
local_files_only = kwargs.pop("local_files_only", False)
356358
revision = kwargs.pop("revision", None)
357359
torch_dtype = kwargs.pop("torch_dtype", None)
360+
checkpoint = kwargs.pop("checkpoint", None)
358361

359362
is_legacy_loading = False
360363

@@ -375,15 +378,16 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
375378

376379
pipeline_class = _get_pipeline_class(cls, config=None)
377380

378-
checkpoint = load_single_file_checkpoint(
379-
pretrained_model_link_or_path,
380-
force_download=force_download,
381-
proxies=proxies,
382-
token=token,
383-
cache_dir=cache_dir,
384-
local_files_only=local_files_only,
385-
revision=revision,
386-
)
381+
if checkpoint is None:
382+
checkpoint = load_single_file_checkpoint(
383+
pretrained_model_link_or_path,
384+
force_download=force_download,
385+
proxies=proxies,
386+
token=token,
387+
cache_dir=cache_dir,
388+
local_files_only=local_files_only,
389+
revision=revision,
390+
)
387391

388392
if config is None:
389393
config = fetch_diffusers_config(checkpoint)

0 commit comments

Comments
 (0)