Skip to content

Commit 7e39516

Browse files
w4ffl35sayakpaul
andauthored
Allow more arguments to be passed to convert_from_ckpt (#7222)
Allow safety and feature extractor arguments to be passed to convert_from_ckpt Allows management of safety checker and feature extractor from outside of the convert ckpt class. Co-authored-by: Sayak Paul <[email protected]>
1 parent 56a7608 commit 7e39516

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,6 +1153,8 @@ def download_from_original_stable_diffusion_ckpt(
11531153
controlnet: Optional[bool] = None,
11541154
adapter: Optional[bool] = None,
11551155
load_safety_checker: bool = True,
1156+
safety_checker: Optional[StableDiffusionSafetyChecker] = None,
1157+
feature_extractor: Optional[AutoFeatureExtractor] = None,
11561158
pipeline_class: DiffusionPipeline = None,
11571159
local_files_only=False,
11581160
vae_path=None,
@@ -1205,6 +1207,12 @@ def download_from_original_stable_diffusion_ckpt(
12051207
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
12061208
load_safety_checker (`bool`, *optional*, defaults to `True`):
12071209
Whether to load the safety checker or not. Defaults to `True`.
1210+
safety_checker (`StableDiffusionSafetyChecker`, *optional*, defaults to `None`):
1211+
Safety checker to use. If this parameter is `None`, the function will load a new instance of
1212+
[StableDiffusionSafetyChecker] by itself, if needed.
1213+
feature_extractor (`AutoFeatureExtractor`, *optional*, defaults to `None`):
1214+
Feature extractor to use. If this parameter is `None`, the function will load a new instance of
1215+
[AutoFeatureExtractor] by itself, if needed.
12081216
pipeline_class (`str`, *optional*, defaults to `None`):
12091217
The pipeline class to use. Pass `None` to determine automatically.
12101218
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -1530,8 +1538,8 @@ def download_from_original_stable_diffusion_ckpt(
15301538
unet=unet,
15311539
scheduler=scheduler,
15321540
controlnet=controlnet,
1533-
safety_checker=None,
1534-
feature_extractor=None,
1541+
safety_checker=safety_checker,
1542+
feature_extractor=feature_extractor,
15351543
)
15361544
if hasattr(pipe, "requires_safety_checker"):
15371545
pipe.requires_safety_checker = False
@@ -1551,8 +1559,8 @@ def download_from_original_stable_diffusion_ckpt(
15511559
unet=unet,
15521560
scheduler=scheduler,
15531561
low_res_scheduler=low_res_scheduler,
1554-
safety_checker=None,
1555-
feature_extractor=None,
1562+
safety_checker=safety_checker,
1563+
feature_extractor=feature_extractor,
15561564
)
15571565

15581566
else:
@@ -1562,8 +1570,8 @@ def download_from_original_stable_diffusion_ckpt(
15621570
tokenizer=tokenizer,
15631571
unet=unet,
15641572
scheduler=scheduler,
1565-
safety_checker=None,
1566-
feature_extractor=None,
1573+
safety_checker=safety_checker,
1574+
feature_extractor=feature_extractor,
15671575
)
15681576
if hasattr(pipe, "requires_safety_checker"):
15691577
pipe.requires_safety_checker = False
@@ -1684,9 +1692,6 @@ def download_from_original_stable_diffusion_ckpt(
16841692
feature_extractor = AutoFeatureExtractor.from_pretrained(
16851693
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
16861694
)
1687-
else:
1688-
safety_checker = None
1689-
feature_extractor = None
16901695

16911696
if controlnet:
16921697
pipe = pipeline_class(

0 commit comments

Comments
 (0)