1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import os
15+ import zipfile
1516from typing import Optional
1617
1718import requests
2324
2425
# MLP code taken from LAION's CLIP-based-NSFW-Detector
# https://github.com/LAION-AI/CLIP-based-NSFW-Detector/issues/7
class Normalization(nn.Module):
    """Standardize inputs using stored per-feature statistics.

    The ``mean`` and ``variance`` buffers start as zeros and ones (an identity
    transform) and are expected to be overwritten when a checkpoint is loaded
    via ``load_state_dict``.

    Args:
        shape: Shape of the ``mean`` and ``variance`` buffers.
        eps: Lower bound applied to the standard deviation before dividing,
            so that a zero-variance feature cannot produce inf/NaN outputs.
            It is a plain attribute, not a buffer, so checkpoint keys are
            unaffected.
    """

    def __init__(self, shape, eps=1e-7):
        super().__init__()
        self.eps = eps
        self.register_buffer("mean", torch.zeros(shape))
        self.register_buffer("variance", torch.ones(shape))

    def forward(self, x):
        """Return ``(x - mean) / max(sqrt(variance), eps)``."""
        # Clamp the divisor rather than the variance so ordinary statistics
        # are used exactly as stored; only degenerate values are affected.
        std = self.variance.sqrt().clamp(min=self.eps)
        return (x - self.mean) / std
36+
37+
class NSFWModel(nn.Module):
    """Binary NSFW classifier head operating on 768-dimensional embeddings.

    The input is normalized and then passed through four linear layers
    (768 -> 64 -> 512 -> 256 -> 1). The first three are followed by ReLU and
    the last by a sigmoid, so every output lies in the open interval (0, 1).

    The attribute names (``norm``, ``linear_1`` ... ``linear_4``) define the
    state-dict keys, so they must stay as they are for checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.norm = Normalization([768])
        self.linear_1 = nn.Linear(768, 64)
        self.linear_2 = nn.Linear(64, 512)
        self.linear_3 = nn.Linear(512, 256)
        self.linear_4 = nn.Linear(256, 1)
        self.act = nn.ReLU()
        self.act_out = nn.Sigmoid()

    def forward(self, x):
        """Map embeddings whose last dimension is 768 to sigmoid scores."""
        x = self.norm(x)
        # The three hidden layers all share the same ReLU activation.
        for hidden in (self.linear_1, self.linear_2, self.linear_3):
            x = self.act(hidden(x))
        return self.act_out(self.linear_4(x))
5356
5457
5558class NsfwClassifier (ImageClassifier ):
@@ -66,7 +69,7 @@ def __init__(
6669 pred_column = pred_column ,
6770 pred_type = float ,
6871 batch_size = batch_size ,
69- embedding_size = 1024 ,
72+ embedding_size = 768 ,
7073 )
7174
7275 if model_path is None :
@@ -76,21 +79,24 @@ def __init__(
7679
7780 @staticmethod
7881 def _get_default_model ():
79- weights_name = "h14_nsfw .pth"
82+ weights_name = "clip_autokeras_binary_nsfw .pth"
8083 model_path = os .path .join (NEMO_CURATOR_HOME , weights_name )
8184 os .makedirs (NEMO_CURATOR_HOME , exist_ok = True )
8285
8386 if not os .path .exists (model_path ):
84- url = f "https://github.com/LAION-AI/CLIP-based-NSFW-Detector/blob/main/ { weights_name } ?raw=true "
87+ url = "https://github.com/LAION-AI/CLIP-based-NSFW-Detector/files/10250461/clip_autokeras_binary_nsfw.zip "
8588 r = requests .get (url )
8689
87- with open (model_path , "wb" ) as f :
90+ raw_zip_path = os .path .join (NEMO_CURATOR_HOME , "nsfw.zip" )
91+ with open (raw_zip_path , "wb" ) as f :
8892 f .write (r .content )
93+ with zipfile .ZipFile (raw_zip_path , "r" ) as f :
94+ f .extractall (NEMO_CURATOR_HOME )
8995
9096 return model_path
9197
9298 def load_model (self , device ):
93- model = H14_NSFW_Detector ( input_size = self . embedding_size ).to (device )
99+ model = NSFWModel ( ).to (device )
94100 weights = torch .load (self .model_path , map_location = torch .device ("cpu" ))
95101 model .load_state_dict (weights )
96102 model .eval ()
0 commit comments