Skip to content

Commit 29d9fb5

Browse files
committed
update
1 parent a054c78 commit 29d9fb5

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,12 @@ def run_safety_checker(self, image, device, dtype):
563563
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
564564
)
565565
return image, has_nsfw_concept
566+
567+
def safety_checker_Level(self, Level):
568+
if self.safety_checker is not None:
569+
self.safety_checker.update_safety_checker_Level(Level)
570+
else:
571+
logger.warning("`safety_checker_Level` is ignored because `safety_checker=None` is passed.")
566572

567573
def decode_latents(self, latents):
568574
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"

src/diffusers/pipelines/stable_diffusion/safety_checker.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,21 @@ def __init__(self, config: CLIPConfig):
4646

4747
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
4848
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
49+
50+
def update_safety_checker_Level(self, Level):
51+
Level_dict = {
52+
"WEAK": -1.0,
53+
"MEDIUM": -0.5,
54+
"NOMAL": 0.0,
55+
"STRONG": 0.5,
56+
"MAX": 1.0,
57+
}
58+
if Level in Level_dict:
59+
Level = Level_dict[Level]
60+
if isinstance(Level, (float, int)):
61+
setattr(self,"adjustment",Level)
62+
else:
63+
raise ValueError("`int` or `float` or one of the following ['WEAK'], ['MEDIUM'], ['NOMAL'], ['STRONG'], ['MAX']")
4964

5065
@torch.no_grad()
5166
def forward(self, clip_input, images):

0 commit comments

Comments
 (0)