Skip to content

Commit 502fb28

Browse files
committed
fix
1 parent 29d9fb5 commit 502fb28

File tree

3 files changed

+40
-13
lines changed

3 files changed

+40
-13
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1986,3 +1986,19 @@ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
19861986
else:
19871987
self.vae.unfuse_qkv_projections()
19881988
self.fusing_vae = False
1989+
1990+
def safety_checker_Level(self, Level):
1991+
"""
1992+
Adjust the filter intensity.
1993+
1994+
Args:
1995+
Level (`int` or `float` or one of the following [`WEAK`], [`MEDIUM`], [`NOMAL`], [`STRONG`], [`MAX`])
1996+
"""
1997+
_safety_checker = getattr(self, "safety_checker", None)
1998+
if _safety_checker is not None:
1999+
if hasattr(_safety_checker, "update_safety_checker_Level"):
2000+
self.safety_checker.update_safety_checker_Level(Level)
2001+
else:
2002+
logger.warning("`safety_checker_Level` is ignored because `update_safety_checker_Level` is not in `safety_checker`.")
2003+
else:
2004+
logger.warning("Since there is no `safety_checker`, `safety_checker_Level` is ignored.")

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -563,12 +563,6 @@ def run_safety_checker(self, image, device, dtype):
563563
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
564564
)
565565
return image, has_nsfw_concept
566-
567-
def safety_checker_Level(self, Level):
568-
if self.safety_checker is not None:
569-
self.safety_checker.update_safety_checker_Level(Level)
570-
else:
571-
logger.warning("`safety_checker_Level` is ignored because `safety_checker=None` is passed.")
572566

573567
def decode_latents(self, latents):
574568
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"

src/diffusers/pipelines/stable_diffusion/safety_checker.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
3434
main_input_name = "clip_input"
3535

3636
_no_split_modules = ["CLIPEncoderLayer"]
37+
3738

3839
def __init__(self, config: CLIPConfig):
3940
super().__init__(config)
@@ -46,8 +47,14 @@ def __init__(self, config: CLIPConfig):
4647

4748
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
4849
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
50+
51+
self.adjustment = 0.0
4952

5053
def update_safety_checker_Level(self, Level):
54+
"""
55+
Args:
56+
Level (`int` or `float` or one of the following [`WEAK`], [`MEDIUM`], [`NOMAL`], [`STRONG`], [`MAX`])
57+
"""
5158
Level_dict = {
5259
"WEAK": -1.0,
5360
"MEDIUM": -0.5,
@@ -56,11 +63,21 @@ def update_safety_checker_Level(self, Level):
5663
"MAX": 1.0,
5764
}
5865
if Level in Level_dict:
59-
Level = Level_dict[Level]
66+
Level = Level_dict[Level]
6067
if isinstance(Level, (float, int)):
6168
setattr(self,"adjustment",Level)
6269
else:
6370
raise ValueError("`int` or `float` or one of the following ['WEAK'], ['MEDIUM'], ['NOMAL'], ['STRONG'], ['MAX']")
71+
72+
if self.adjustment<0:
73+
logger.warning(
74+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
75+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
76+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
77+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
78+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
79+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
80+
)
6481

6582
@torch.no_grad()
6683
def forward(self, clip_input, images):
@@ -78,20 +95,20 @@ def forward(self, clip_input, images):
7895

7996
# increase this value to create a stronger `nfsw` filter
8097
# at the cost of increasing the possibility of filtering benign images
81-
adjustment = 0.0
98+
#adjustment = 0.0
8299

83100
for concept_idx in range(len(special_cos_dist[0])):
84101
concept_cos = special_cos_dist[i][concept_idx]
85102
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
86-
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
103+
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + self.adjustment, 3)
87104
if result_img["special_scores"][concept_idx] > 0:
88105
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
89-
adjustment = 0.01
106+
self.adjustment = 0.01
90107

91108
for concept_idx in range(len(cos_dist[0])):
92109
concept_cos = cos_dist[i][concept_idx]
93110
concept_threshold = self.concept_embeds_weights[concept_idx].item()
94-
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
111+
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + self.adjustment, 3)
95112
if result_img["concept_scores"][concept_idx] > 0:
96113
result_img["bad_concepts"].append(concept_idx)
97114

@@ -124,9 +141,9 @@ def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor):
124141

125142
# increase this value to create a stronger `nsfw` filter
126143
# at the cost of increasing the possibility of filtering benign images
127-
adjustment = 0.0
144+
#adjustment = 0.0
128145

129-
special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
146+
special_scores = special_cos_dist - self.special_care_embeds_weights + self.adjustment
130147
# special_scores = special_scores.round(decimals=3)
131148
special_care = torch.any(special_scores > 0, dim=1)
132149
special_adjustment = special_care * 0.01

0 commit comments

Comments
 (0)