diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 66b56740ef13..0e3812938fe0 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -2073,3 +2073,38 @@ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): else: self.vae.unfuse_qkv_projections() self.fusing_vae = False + + def safety_checker_level(self, level): + """ + Adjust the safety checker level. + + Args: + Level (`int` or `float` or one of the following [`WEAK`], [`MEDIUM`], [`NOMAL`], [`STRONG`], [`MAX`]): + The level of safety checker adjustment, either as an integer, a float, or one of the predefined levels. + Negative values decrease the filtering strength, while positive values increase it. + """ + # Retrieve the safety_checker attribute from the instance + _safety_checker = getattr(self, "safety_checker", None) + + # Check if the safety_checker exists + if _safety_checker is not None: + # Check if the safety_checker has the update_safety_checker_Level method + if hasattr(_safety_checker, "update_safety_checker_Level"): + # Update the safety checker level using the provided method + self.safety_checker.update_safety_checker_Level(level) + else: + # Log a warning if the method is not found in safety_checker + logger.warning( + "`safety_checker_level` is ignored because `update_safety_checker_Level` is not in `safety_checker`." + ) + else: + # Log a warning if safety_checker is not present + logger.warning("Since there is no `safety_checker`, `safety_checker_level` is ignored.") + + @property + def filter_level(self): + """ + Return: + `int` ,`float` or None + """ + return getattr(getattr(self, "safety_checker", None), "adjustment", None) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 3a0e86409e4a..5c4e606a4af4 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -47,6 +47,50 @@ def __init__(self, config: CLIPConfig): self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + self.adjustment = 0.0 + + def update_safety_checker_Level(self, Level): + """ + Adjust the safety checker level. + + Parameters: + Level (`int` or `float` or one of the following [`WEAK`], [`MEDIUM`], [`NOMAL`], [`STRONG`], [`MAX`]): + The level of safety checker adjustment, either as an integer, a float, or one of the predefined levels. + Negative values decrease the filtering strength, while positive values increase it. + """ + Level_dict = { + "WEAK": -0.0690, + "MEDIUM": -0.0175, + "NOMAL": 0.0, + "STRONG": 0.0150, + "MAX": 0.0740, + } + + # If the provided Level is a predefined string, convert it to the corresponding float value + if Level in Level_dict: + Level = Level_dict[Level] + + # Check if the Level is a float or an integer + if isinstance(Level, (float, int)): + setattr(self, "adjustment", Level) # Set the adjustment attribute to the Level value + else: + # Raise an error if Level is not a valid type or predefined string + raise ValueError( + "`int` or `float` or one of the following ['WEAK'], ['MEDIUM'], ['NOMAL'], ['STRONG'], ['MAX']" + ) + + # Log a warning if the adjustment level is weakened (negative value) + if self.adjustment < 0: + logger.warning( + "You have weakened the filtering strength of safety checker. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " When reducing the filtering strength, take the same action as when disabling the safety checker." + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + @torch.no_grad() def forward(self, clip_input, images): pooled_output = self.vision_model(clip_input)[1] # pooled_output @@ -63,20 +107,19 @@ def forward(self, clip_input, images): # increase this value to create a stronger `nfsw` filter # at the cost of increasing the possibility of filtering benign images - adjustment = 0.0 for concept_idx in range(len(special_cos_dist[0])): concept_cos = special_cos_dist[i][concept_idx] concept_threshold = self.special_care_embeds_weights[concept_idx].item() - result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + self.adjustment, 3) if result_img["special_scores"][concept_idx] > 0: result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) - adjustment = 0.01 + self.adjustment += 0.01 for concept_idx in range(len(cos_dist[0])): concept_cos = cos_dist[i][concept_idx] concept_threshold = self.concept_embeds_weights[concept_idx].item() - result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + self.adjustment, 3) if result_img["concept_scores"][concept_idx] > 0: result_img["bad_concepts"].append(concept_idx) @@ -109,9 +152,8 @@ def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor): # increase this value to create a stronger `nsfw` filter # at the cost of increasing the possibility of filtering benign images - adjustment = 0.0 - special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + special_scores = special_cos_dist - self.special_care_embeds_weights + self.adjustment # special_scores = special_scores.round(decimals=3) special_care = torch.any(special_scores > 0, dim=1) special_adjustment = special_care * 0.01