@@ -17,6 +17,9 @@ def __init__(self):
1717 self .conv_moderation_responses : List [
1818 Dict [str , Dict [str , Union [str , Dict [str , float ]]]]
1919 ] = []
20+ self .text_flagged = False
21+ self .csam_flagged = False
22+ self .nsfw_flagged = False
2023
2124 def _image_moderation_filter (self , image : Image ) -> Tuple [bool , bool ]:
2225 """Function that detects whether image violates moderation policies.
@@ -34,6 +37,11 @@ def _text_moderation_filter(self, text: str) -> bool:
3437 """
3538 raise NotImplementedError
3639
40+ def reset_moderation_flags (self ):
41+ self .text_flagged = False
42+ self .csam_flagged = False
43+ self .nsfw_flagged = False
44+
3745 def image_and_text_moderation_filter (
3846 self , image : Image , text : str
3947 ) -> Dict [str , Dict [str , Union [str , Dict [str , float ]]]]:
@@ -77,9 +85,6 @@ def __init__(self, use_remote_storage: bool = False):
7785 to the moderation API.
7886 """
7987 super ().__init__ ()
80- self .text_flagged = False
81- self .csam_flagged = False
82- self .nsfw_flagged = False
8388
8489 def _image_moderation_request (
8590 self , image_bytes : bytes , endpoint : str , api_key : str
@@ -253,6 +258,7 @@ def image_and_text_moderation_filter(
253258 }
254259 """
255260 print ("moderating text: " , text )
261+ self .reset_moderation_flags ()
256262 text_flagged_map = self .text_moderation_filter (text , model_list , do_moderation )
257263
258264 if image is not None :
0 commit comments