Skip to content

Commit 4942017

Browse files
committed
Add the gated option.
1 parent 48bec80 commit 4942017

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,8 @@ def search_huggingface(search_word: str, **kwargs):
566566
Tag to filter models by pipeline.
567567
token (`str`, *optional*):
568568
API token for Hugging Face authentication.
569+
gated (`bool`, *optional*, defaults to `False` ):
570+
A boolean to filter models on the Hub that are gated or not.
569571
skip_error (`bool`, *optional*, defaults to `False`):
570572
Whether to skip errors and return None.
571573
@@ -580,6 +582,7 @@ def search_huggingface(search_word: str, **kwargs):
580582
include_params = kwargs.pop("include_params", False)
581583
pipeline_tag = kwargs.pop("pipeline_tag", None)
582584
token = kwargs.pop("token", None)
585+
gated = kwargs.pop("gated", False)
583586
skip_error = kwargs.pop("skip_error", False)
584587

585588
# Get the type and loading method for the keyword
@@ -622,6 +625,7 @@ def search_huggingface(search_word: str, **kwargs):
622625
fetch_config=True,
623626
pipeline_tag=pipeline_tag,
624627
full=True,
628+
gated=gated,
625629
token=token
626630
)
627631
model_dicts = [asdict(value) for value in list(hf_models)]
@@ -1273,6 +1277,8 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
12731277
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
12741278
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
12751279
weights. If set to `False`, safetensors weights are not loaded.
1280+
gated (`bool`, *optional*, defaults to `False` ):
1281+
A boolean to filter models on the Hub that are gated or not.
12761282
kwargs (remaining dictionary of keyword arguments, *optional*):
12771283
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
12781284
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -1793,6 +1799,8 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
17931799
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
17941800
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
17951801
weights. If set to `False`, safetensors weights are not loaded.
1802+
gated (`bool`, *optional*, defaults to `False` ):
1803+
A boolean to filter models on the Hub that are gated or not.
17961804
kwargs (remaining dictionary of keyword arguments, *optional*):
17971805
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
17981806
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -2307,6 +2315,8 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
23072315
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
23082316
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
23092317
weights. If set to `False`, safetensors weights are not loaded.
2318+
gated (`bool`, *optional*, defaults to `False` ):
2319+
A boolean to filter models on the Hub that are gated or not.
23102320
kwargs (remaining dictionary of keyword arguments, *optional*):
23112321
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
23122322
class). The overwritten components are passed directly to the pipelines `__init__` method. See example

0 commit comments

Comments
 (0)