Skip to content

Commit cbfed0c

Browse files
[Config] Add optional arguments (#1395)
* Optional Components * uP * finish * finish * finish * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * up * Update src/diffusers/pipeline_utils.py * improve Co-authored-by: Pedro Cuenca <[email protected]>
1 parent e0e86b7 commit cbfed0c

15 files changed

+292
-63
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,13 @@ class DiffusionPipeline(ConfigMixin):
129129
130130
Class attributes:
131131
132-
- **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
132+
- **config_name** (`str`) -- name of the config file that will store the class and module names of all
133133
components of the diffusion pipeline.
134+
- **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be
135+
passed for the pipeline to function (should be overridden by subclasses).
134136
"""
135137
config_name = "model_index.json"
138+
_optional_components = []
136139

137140
def register_modules(self, **kwargs):
138141
# import it here to avoid circular import
@@ -184,12 +187,19 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
184187
model_index_dict.pop("_diffusers_version")
185188
model_index_dict.pop("_module", None)
186189

190+
expected_modules, optional_kwargs = self._get_signature_keys(self)
191+
192+
def is_saveable_module(name, value):
193+
if name not in expected_modules:
194+
return False
195+
if name in self._optional_components and value[0] is None:
196+
return False
197+
return True
198+
199+
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
200+
187201
for pipeline_component_name in model_index_dict.keys():
188202
sub_model = getattr(self, pipeline_component_name)
189-
if sub_model is None:
190-
# edge case for saving a pipeline with safety_checker=None
191-
continue
192-
193203
model_cls = sub_model.__class__
194204

195205
save_method_name = None
@@ -523,26 +533,27 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
523533
# some modules can be passed directly to the init
524534
# in this case they are already instantiated in `kwargs`
525535
# extract them here
526-
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
536+
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
527537
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
538+
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
528539

529540
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
530541

542+
# define init kwargs
543+
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
544+
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
545+
546+
# remove `null` components
547+
init_dict = {k: v for k, v in init_dict.items() if v[0] is not None}
548+
531549
if len(unused_kwargs) > 0:
532550
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
533551

534-
init_kwargs = {}
535-
536552
# import it here to avoid circular import
537553
from diffusers import pipelines
538554

539555
# 3. Load each module in the pipeline
540556
for name, (library_name, class_name) in init_dict.items():
541-
if class_name is None:
542-
# edge case for when the pipeline was saved with safety_checker=None
543-
init_kwargs[name] = None
544-
continue
545-
546557
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
547558
if class_name.startswith("Flax"):
548559
class_name = class_name[4:]
@@ -570,7 +581,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
570581
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
571582
f" {expected_class_obj}"
572583
)
573-
elif passed_class_obj[name] is None:
584+
elif passed_class_obj[name] is None and name not in pipeline_class._optional_components:
574585
logger.warning(
575586
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
576587
f" that this might lead to problems when using {pipeline_class} and is not recommended."
@@ -651,11 +662,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
651662

652663
# 4. Potentially add passed objects if expected
653664
missing_modules = set(expected_modules) - set(init_kwargs.keys())
654-
if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()):
665+
passed_modules = list(passed_class_obj.keys())
666+
optional_modules = pipeline_class._optional_components
667+
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
655668
for module in missing_modules:
656-
init_kwargs[module] = passed_class_obj[module]
669+
init_kwargs[module] = passed_class_obj.get(module, None)
657670
elif len(missing_modules) > 0:
658-
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys()))
671+
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
659672
raise ValueError(
660673
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
661674
)
@@ -664,6 +677,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
664677
model = pipeline_class(**init_kwargs)
665678
return model
666679

680+
@staticmethod
681+
def _get_signature_keys(obj):
682+
parameters = inspect.signature(obj.__init__).parameters
683+
required_parameters = {k: v for k, v in parameters.items() if v.default is not True}
684+
optional_parameters = set({k for k, v in parameters.items() if v.default is True})
685+
expected_modules = set(required_parameters.keys()) - set(["self"])
686+
return expected_modules, optional_parameters
687+
667688
@property
668689
def components(self) -> Dict[str, Any]:
669690
r"""
@@ -688,8 +709,10 @@ def components(self) -> Dict[str, Any]:
688709
Returns:
689710
A dictionaly containing all the modules needed to initialize the pipeline.
690711
"""
691-
components = {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
692-
expected_modules = set(inspect.signature(self.__init__).parameters.keys()) - set(["self"])
712+
expected_modules, optional_parameters = self._get_signature_keys(self)
713+
components = {
714+
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
715+
}
693716

694717
if set(components.keys()) != expected_modules:
695718
raise ValueError(

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
6767
feature_extractor ([`CLIPFeatureExtractor`]):
6868
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
6969
"""
70+
_optional_components = ["safety_checker", "feature_extractor"]
7071

7172
def __init__(
7273
self,
@@ -84,6 +85,7 @@ def __init__(
8485
],
8586
safety_checker: StableDiffusionSafetyChecker,
8687
feature_extractor: CLIPFeatureExtractor,
88+
requires_safety_checker: bool = True,
8789
):
8890
super().__init__()
8991

@@ -114,7 +116,7 @@ def __init__(
114116
new_config["clip_sample"] = False
115117
scheduler._internal_dict = FrozenDict(new_config)
116118

117-
if safety_checker is None:
119+
if safety_checker is None and requires_safety_checker:
118120
logger.warning(
119121
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
120122
" that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
@@ -124,6 +126,12 @@ def __init__(
124126
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
125127
)
126128

129+
if safety_checker is not None and feature_extractor is None:
130+
raise ValueError(
131+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
132+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
133+
)
134+
127135
self.register_modules(
128136
vae=vae,
129137
text_encoder=text_encoder,
@@ -133,6 +141,7 @@ def __init__(
133141
safety_checker=safety_checker,
134142
feature_extractor=feature_extractor,
135143
)
144+
self.register_to_config(requires_safety_checker=requires_safety_checker)
136145

137146
def enable_xformers_memory_efficient_attention(self):
138147
r"""

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
8080
feature_extractor ([`CLIPFeatureExtractor`]):
8181
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
8282
"""
83+
_optional_components = ["safety_checker", "feature_extractor"]
8384

8485
def __init__(
8586
self,
@@ -97,6 +98,7 @@ def __init__(
9798
],
9899
safety_checker: StableDiffusionSafetyChecker,
99100
feature_extractor: CLIPFeatureExtractor,
101+
requires_safety_checker: bool = True,
100102
):
101103
super().__init__()
102104

@@ -127,7 +129,7 @@ def __init__(
127129
new_config["clip_sample"] = False
128130
scheduler._internal_dict = FrozenDict(new_config)
129131

130-
if safety_checker is None:
132+
if safety_checker is None and requires_safety_checker:
131133
logger.warning(
132134
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
133135
" that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
@@ -137,6 +139,12 @@ def __init__(
137139
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
138140
)
139141

142+
if safety_checker is not None and feature_extractor is None:
143+
raise ValueError(
144+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
145+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
146+
)
147+
140148
self.register_modules(
141149
vae=vae,
142150
text_encoder=text_encoder,
@@ -146,6 +154,7 @@ def __init__(
146154
safety_checker=safety_checker,
147155
feature_extractor=feature_extractor,
148156
)
157+
self.register_to_config(requires_safety_checker=requires_safety_checker)
149158

150159
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
151160
r"""

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
132132
feature_extractor ([`CLIPFeatureExtractor`]):
133133
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
134134
"""
135+
_optional_components = ["safety_checker", "feature_extractor"]
135136

136137
def __init__(
137138
self,
@@ -142,6 +143,7 @@ def __init__(
142143
scheduler: DDIMScheduler,
143144
safety_checker: StableDiffusionSafetyChecker,
144145
feature_extractor: CLIPFeatureExtractor,
146+
requires_safety_checker: bool = True,
145147
):
146148
super().__init__()
147149

@@ -159,7 +161,7 @@ def __init__(
159161
new_config["steps_offset"] = 1
160162
scheduler._internal_dict = FrozenDict(new_config)
161163

162-
if safety_checker is None:
164+
if safety_checker is None and requires_safety_checker:
163165
logger.warning(
164166
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
165167
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
@@ -169,6 +171,12 @@ def __init__(
169171
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
170172
)
171173

174+
if safety_checker is not None and feature_extractor is None:
175+
raise ValueError(
176+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
177+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
178+
)
179+
172180
self.register_modules(
173181
vae=vae,
174182
text_encoder=text_encoder,
@@ -178,6 +186,7 @@ def __init__(
178186
safety_checker=safety_checker,
179187
feature_extractor=feature_extractor,
180188
)
189+
self.register_to_config(requires_safety_checker=requires_safety_checker)
181190

182191
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
183192
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(
5151
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
5252
safety_checker: OnnxRuntimeModel,
5353
feature_extractor: CLIPFeatureExtractor,
54+
requires_safety_checker: bool = True,
5455
):
5556
super().__init__()
5657

@@ -81,6 +82,22 @@ def __init__(
8182
new_config["clip_sample"] = False
8283
scheduler._internal_dict = FrozenDict(new_config)
8384

85+
if safety_checker is None and requires_safety_checker:
86+
logger.warning(
87+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
88+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
89+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
90+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
91+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
92+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
93+
)
94+
95+
if safety_checker is not None and feature_extractor is None:
96+
raise ValueError(
97+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
98+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
99+
)
100+
84101
self.register_modules(
85102
vae_encoder=vae_encoder,
86103
vae_decoder=vae_decoder,
@@ -91,6 +108,7 @@ def __init__(
91108
safety_checker=safety_checker,
92109
feature_extractor=feature_extractor,
93110
)
111+
self.register_to_config(requires_safety_checker=requires_safety_checker)
94112

95113
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
96114
r"""

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(
8787
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
8888
safety_checker: OnnxRuntimeModel,
8989
feature_extractor: CLIPFeatureExtractor,
90+
requires_safety_checker: bool = True,
9091
):
9192
super().__init__()
9293

@@ -117,7 +118,7 @@ def __init__(
117118
new_config["clip_sample"] = False
118119
scheduler._internal_dict = FrozenDict(new_config)
119120

120-
if safety_checker is None:
121+
if safety_checker is None and requires_safety_checker:
121122
logger.warning(
122123
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
123124
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
@@ -127,6 +128,12 @@ def __init__(
127128
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
128129
)
129130

131+
if safety_checker is not None and feature_extractor is None:
132+
raise ValueError(
133+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
134+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
135+
)
136+
130137
self.register_modules(
131138
vae_encoder=vae_encoder,
132139
vae_decoder=vae_decoder,
@@ -137,6 +144,7 @@ def __init__(
137144
safety_checker=safety_checker,
138145
feature_extractor=feature_extractor,
139146
)
147+
self.register_to_config(requires_safety_checker=requires_safety_checker)
140148

141149
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
142150
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(
100100
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
101101
safety_checker: OnnxRuntimeModel,
102102
feature_extractor: CLIPFeatureExtractor,
103+
requires_safety_checker: bool = True,
103104
):
104105
super().__init__()
105106
logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
@@ -131,7 +132,7 @@ def __init__(
131132
new_config["clip_sample"] = False
132133
scheduler._internal_dict = FrozenDict(new_config)
133134

134-
if safety_checker is None:
135+
if safety_checker is None and requires_safety_checker:
135136
logger.warning(
136137
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
137138
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
@@ -141,6 +142,12 @@ def __init__(
141142
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
142143
)
143144

145+
if safety_checker is not None and feature_extractor is None:
146+
raise ValueError(
147+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
148+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
149+
)
150+
144151
self.register_modules(
145152
vae_encoder=vae_encoder,
146153
vae_decoder=vae_decoder,
@@ -151,6 +158,7 @@ def __init__(
151158
safety_checker=safety_checker,
152159
feature_extractor=feature_extractor,
153160
)
161+
self.register_to_config(requires_safety_checker=requires_safety_checker)
154162

155163
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
156164
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):

0 commit comments

Comments
 (0)