@@ -1153,6 +1153,8 @@ def download_from_original_stable_diffusion_ckpt(
1153
1153
controlnet : Optional [bool ] = None ,
1154
1154
adapter : Optional [bool ] = None ,
1155
1155
load_safety_checker : bool = True ,
1156
+ safety_checker : Optional [StableDiffusionSafetyChecker ] = None ,
1157
+ feature_extractor : Optional [AutoFeatureExtractor ] = None ,
1156
1158
pipeline_class : DiffusionPipeline = None ,
1157
1159
local_files_only = False ,
1158
1160
vae_path = None ,
@@ -1205,6 +1207,12 @@ def download_from_original_stable_diffusion_ckpt(
1205
1207
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
1206
1208
load_safety_checker (`bool`, *optional*, defaults to `True`):
1207
1209
Whether to load the safety checker or not. Defaults to `True`.
1210
+ safety_checker (`StableDiffusionSafetyChecker`, *optional*, defaults to `None`):
1211
+ Safety checker to use. If this parameter is `None`, the function will load a new instance of
1212
+ [StableDiffusionSafetyChecker] by itself, if needed.
1213
+ feature_extractor (`AutoFeatureExtractor`, *optional*, defaults to `None`):
1214
+ Feature extractor to use. If this parameter is `None`, the function will load a new instance of
1215
+ [AutoFeatureExtractor] by itself, if needed.
1208
1216
pipeline_class (`str`, *optional*, defaults to `None`):
1209
1217
The pipeline class to use. Pass `None` to determine automatically.
1210
1218
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -1530,8 +1538,8 @@ def download_from_original_stable_diffusion_ckpt(
1530
1538
unet = unet ,
1531
1539
scheduler = scheduler ,
1532
1540
controlnet = controlnet ,
1533
- safety_checker = None ,
1534
- feature_extractor = None ,
1541
+ safety_checker = safety_checker ,
1542
+ feature_extractor = feature_extractor ,
1535
1543
)
1536
1544
if hasattr (pipe , "requires_safety_checker" ):
1537
1545
pipe .requires_safety_checker = False
@@ -1551,8 +1559,8 @@ def download_from_original_stable_diffusion_ckpt(
1551
1559
unet = unet ,
1552
1560
scheduler = scheduler ,
1553
1561
low_res_scheduler = low_res_scheduler ,
1554
- safety_checker = None ,
1555
- feature_extractor = None ,
1562
+ safety_checker = safety_checker ,
1563
+ feature_extractor = feature_extractor ,
1556
1564
)
1557
1565
1558
1566
else :
@@ -1562,8 +1570,8 @@ def download_from_original_stable_diffusion_ckpt(
1562
1570
tokenizer = tokenizer ,
1563
1571
unet = unet ,
1564
1572
scheduler = scheduler ,
1565
- safety_checker = None ,
1566
- feature_extractor = None ,
1573
+ safety_checker = safety_checker ,
1574
+ feature_extractor = feature_extractor ,
1567
1575
)
1568
1576
if hasattr (pipe , "requires_safety_checker" ):
1569
1577
pipe .requires_safety_checker = False
@@ -1684,9 +1692,6 @@ def download_from_original_stable_diffusion_ckpt(
1684
1692
feature_extractor = AutoFeatureExtractor .from_pretrained (
1685
1693
"CompVis/stable-diffusion-safety-checker" , local_files_only = local_files_only
1686
1694
)
1687
- else :
1688
- safety_checker = None
1689
- feature_extractor = None
1690
1695
1691
1696
if controlnet :
1692
1697
pipe = pipeline_class (
0 commit comments