@@ -34,6 +34,7 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
34
34
main_input_name = "clip_input"
35
35
36
36
_no_split_modules = ["CLIPEncoderLayer" ]
37
+
37
38
38
39
def __init__ (self , config : CLIPConfig ):
39
40
super ().__init__ (config )
@@ -46,8 +47,14 @@ def __init__(self, config: CLIPConfig):
46
47
47
48
self .concept_embeds_weights = nn .Parameter (torch .ones (17 ), requires_grad = False )
48
49
self .special_care_embeds_weights = nn .Parameter (torch .ones (3 ), requires_grad = False )
50
+
51
+ self .adjustment = 0.0
49
52
50
53
def update_safety_checker_Level(self, Level):
    """Adjust how aggressively the NSFW filter scores concepts.

    Sets ``self.adjustment``, which ``forward``/``forward_onnx`` add to every
    concept score before thresholding: a positive value makes the filter
    stricter, a negative value weakens it.

    Args:
        Level (`int`, `float`, or `str`):
            Either a numeric offset, or one of the string presets
            ``"WEAK"``, ``"MEDIUM"``, ``"NOMAL"``, ``"STRONG"``, ``"MAX"``.
            (``"NOMAL"`` — sic — is the spelling the API accepts.)

    Raises:
        ValueError: If `Level` is neither numeric nor a known preset.
    """
    Level_dict = {
        "WEAK": -1.0,
        "MEDIUM": -0.5,
        # NOTE(review): the next two values were not visible in the chunk
        # under review; 0.0 / 0.5 matches the even spacing of the other
        # presets — confirm against the original source.
        "NOMAL": 0.0,
        "STRONG": 0.5,
        "MAX": 1.0,
    }
    # Map a preset name to its numeric offset; numeric input passes through.
    if Level in Level_dict:
        Level = Level_dict[Level]
    if isinstance(Level, (float, int)):
        self.adjustment = Level
    else:
        raise ValueError("`int` or `float` or one of the following ['WEAK'], ['MEDIUM'], ['NOMAL'], ['STRONG'], ['MAX']")

    if self.adjustment < 0:
        # A negative adjustment weakens (not disables) the NSFW filter.
        # The previous message was copied verbatim from the
        # `safety_checker=None` code path and incorrectly claimed the
        # checker had been disabled.
        logger.warning(
            f"You have weakened the safety checker for {self.__class__} by setting a negative adjustment."
            " Ensure that you abide to the conditions of the Stable Diffusion license and do not expose"
            " unfiltered results in services or applications open to the public. Both the diffusers team"
            " and Hugging Face strongly recommend to keep the safety filter at full strength in all public"
            " facing circumstances, weakening it only for use-cases that involve analyzing network behavior"
            " or auditing its results. For more information, please have a look at"
            " https://github.com/huggingface/diffusers/pull/254 ."
        )
64
81
65
82
@torch .no_grad ()
66
83
def forward (self , clip_input , images ):
@@ -78,20 +95,20 @@ def forward(self, clip_input, images):
78
95
79
96
# increase this value to create a stronger `nfsw` filter
80
97
# at the cost of increasing the possibility of filtering benign images
81
- adjustment = 0.0
98
+ # adjustment = 0.0
82
99
83
100
for concept_idx in range (len (special_cos_dist [0 ])):
84
101
concept_cos = special_cos_dist [i ][concept_idx ]
85
102
concept_threshold = self .special_care_embeds_weights [concept_idx ].item ()
86
- result_img ["special_scores" ][concept_idx ] = round (concept_cos - concept_threshold + adjustment , 3 )
103
+ result_img ["special_scores" ][concept_idx ] = round (concept_cos - concept_threshold + self . adjustment , 3 )
87
104
if result_img ["special_scores" ][concept_idx ] > 0 :
88
105
result_img ["special_care" ].append ({concept_idx , result_img ["special_scores" ][concept_idx ]})
89
- adjustment = 0.01
106
+ self . adjustment = 0.01
90
107
91
108
for concept_idx in range (len (cos_dist [0 ])):
92
109
concept_cos = cos_dist [i ][concept_idx ]
93
110
concept_threshold = self .concept_embeds_weights [concept_idx ].item ()
94
- result_img ["concept_scores" ][concept_idx ] = round (concept_cos - concept_threshold + adjustment , 3 )
111
+ result_img ["concept_scores" ][concept_idx ] = round (concept_cos - concept_threshold + self . adjustment , 3 )
95
112
if result_img ["concept_scores" ][concept_idx ] > 0 :
96
113
result_img ["bad_concepts" ].append (concept_idx )
97
114
@@ -124,9 +141,9 @@ def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor):
124
141
125
142
# increase this value to create a stronger `nsfw` filter
126
143
# at the cost of increasing the possibility of filtering benign images
127
- adjustment = 0.0
144
+ # adjustment = 0.0
128
145
129
- special_scores = special_cos_dist - self .special_care_embeds_weights + adjustment
146
+ special_scores = special_cos_dist - self .special_care_embeds_weights + self . adjustment
130
147
# special_scores = special_scores.round(decimals=3)
131
148
special_care = torch .any (special_scores > 0 , dim = 1 )
132
149
special_adjustment = special_care * 0.01
0 commit comments