Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
29d9fb5
update
suzukimain Aug 3, 2024
502fb28
fix
suzukimain Aug 3, 2024
3eb0186
fix
suzukimain Aug 3, 2024
1b68031
fix
suzukimain Aug 3, 2024
c1a96a0
fix
suzukimain Aug 3, 2024
12ea447
Merge branch 'main' into safety_checker
suzukimain Sep 11, 2024
b586bc9
Update warning messages
suzukimain Sep 11, 2024
2261c9e
Merge branch 'main' into safety_checker
suzukimain Oct 3, 2024
cfdb4c2
Merge branch 'main' into safety_checker
suzukimain Oct 4, 2024
d7963f8
Merge branch 'main' into safety_checker
suzukimain Oct 21, 2024
567f77a
Merge branch 'main' into safety_checker
suzukimain Oct 23, 2024
53e5995
Merge branch 'main' into safety_checker
suzukimain Oct 26, 2024
4dc10fb
Merge branch 'main' into safety_checker
suzukimain Oct 30, 2024
b52e22d
Merge branch 'main' into safety_checker
suzukimain Nov 8, 2024
190a641
Merge branch 'main' into safety_checker
suzukimain Nov 17, 2024
a093626
Update src/diffusers/pipelines/stable_diffusion/safety_checker.py
suzukimain Dec 13, 2024
3c734d4
Update src/diffusers/pipelines/stable_diffusion/safety_checker.py
suzukimain Dec 13, 2024
1065c60
Update src/diffusers/pipelines/stable_diffusion/safety_checker.py
suzukimain Dec 13, 2024
51fc901
Update threshold dictionary
suzukimain Dec 13, 2024
086ec88
Merge branch 'safety_checker' of https://github.com/suzukimain/diffus…
suzukimain Dec 13, 2024
5b12d86
Merge branch 'main' into safety_checker
suzukimain Dec 13, 2024
571f48b
Merge branch 'main' into safety_checker
suzukimain Dec 13, 2024
d92ab90
Merge branch 'huggingface:main' into safety_checker
suzukimain Dec 15, 2024
cd4ce6e
Code formatting
suzukimain Dec 15, 2024
ee2f922
fix
suzukimain Dec 15, 2024
9b5ed48
fix
suzukimain Dec 15, 2024
3c7842c
Merge branch 'main' into safety_checker
suzukimain Dec 17, 2024
f13f798
Merge branch 'main' into safety_checker
suzukimain Dec 22, 2024
6a1532b
Change values for adjusting filtering strength
suzukimain Dec 22, 2024
c158e09
Merge branch 'main' into safety_checker
suzukimain Dec 23, 2024
725d0c0
Merge branch 'main' into safety_checker
suzukimain Dec 23, 2024
b3735e0
Merge branch 'main' into safety_checker
suzukimain Dec 23, 2024
906a7d6
Merge branch 'main' into safety_checker
hlky Apr 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1929,3 +1929,26 @@ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False

def safety_checker_level(self, level):
"""
Adjust the filter intensity.

Args:
level (`int` or `float` or one of the following [`WEAK`], [`MEDIUM`], [`NOMAL`], [`STRONG`], [`MAX`])
"""
_safety_checker = getattr(self, "safety_checker", None)
if _safety_checker is not None:
if hasattr(_safety_checker, "update_safety_checker_Level"):
self.safety_checker.update_safety_checker_Level(level)
else:
logger.warning("`safety_checker_level` is ignored because `update_safety_checker_Level` is not in `safety_checker`.")
else:
logger.warning("Since there is no `safety_checker`, `safety_checker_level` is ignored.")

def filter_level(self):
"""
Return:
`int` ,`float` or None
"""
return getattr(getattr(self,"safety_checker",None), "adjustment", None)
44 changes: 38 additions & 6 deletions src/diffusers/pipelines/stable_diffusion/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,38 @@ def __init__(self, config: CLIPConfig):

self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)

self.adjustment = 0.0

def update_safety_checker_Level(self, Level):
"""
Args:
Level (`int` or `float` or one of the following [`WEAK`], [`MEDIUM`], [`NOMAL`], [`STRONG`], [`MAX`])
"""
Level_dict = {
"WEAK": -1.0,
"MEDIUM": -0.5,
"NOMAL": 0.0,
"STRONG": 0.5,
"MAX": 1.0,
}
if Level in Level_dict:
Level = Level_dict[Level]
if isinstance(Level, (float, int)):
setattr(self,"adjustment",Level)
else:
raise ValueError("`int` or `float` or one of the following ['WEAK'], ['MEDIUM'], ['NOMAL'], ['STRONG'], ['MAX']")

if self.adjustment<0:
logger.warning(
f"You have weakened the filtering strength of safety checker. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" When reducing the filtering strength, take the same action as when disabling the safety checker."
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)

@torch.no_grad()
def forward(self, clip_input, images):
Expand All @@ -63,20 +95,20 @@ def forward(self, clip_input, images):

# increase this value to create a stronger `nfsw` filter
# at the cost of increasing the possibility of filtering benign images
adjustment = 0.0
#adjustment = 0.0

for concept_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concept_idx]
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + self.adjustment, 3)
if result_img["special_scores"][concept_idx] > 0:
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
adjustment = 0.01
self.adjustment = 0.01

for concept_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concept_idx]
concept_threshold = self.concept_embeds_weights[concept_idx].item()
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + self.adjustment, 3)
if result_img["concept_scores"][concept_idx] > 0:
result_img["bad_concepts"].append(concept_idx)

Expand Down Expand Up @@ -109,9 +141,9 @@ def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor):

# increase this value to create a stronger `nsfw` filter
# at the cost of increasing the possibility of filtering benign images
adjustment = 0.0
#adjustment = 0.0

special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
special_scores = special_cos_dist - self.special_care_embeds_weights + self.adjustment
# special_scores = special_scores.round(decimals=3)
special_care = torch.any(special_scores > 0, dim=1)
special_adjustment = special_care * 0.01
Expand Down