1919from diffusers .pipelines .stable_diffusion .safety_checker import StableDiffusionSafetyChecker
2020from transformers import AutoFeatureExtractor
2121
22- feature_extractor = AutoFeatureExtractor .from_pretrained ("CompVis/stable-diffusion-v-1-3" , use_auth_token = True )
23- safety_checker = StableDiffusionSafetyChecker .from_pretrained ("CompVis/stable-diffusion-v-1-3" , use_auth_token = True )
22+ # load safety model
23+ safety_model_id = "CompVis/stable-diffusion-v-1-3"
24+ safety_feature_extractor = AutoFeatureExtractor .from_pretrained (safety_model_id , use_auth_token = True )
25+ safety_checker = StableDiffusionSafetyChecker .from_pretrained (safety_model_id , use_auth_token = True )
2426
2527def chunk (it , size ):
2628 it = iter (it )
@@ -266,16 +268,23 @@ def main():
266268
267269 x_samples_ddim = model .decode_first_stage (samples_ddim )
268270 x_samples_ddim = torch .clamp ((x_samples_ddim + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
271+ x_samples_ddim = x_samples_ddim .cpu ().permute (0 , 2 , 3 , 1 ).numpy ()
272+
273+ x_image = x_samples_ddim
274+ safety_checker_input = safety_feature_extractor (numpy_to_pil (x_image ), return_tensors = "pt" )
275+ x_checked_image , has_nsfw_concept = safety_checker (images = x_image , clip_input = safety_checker_input .pixel_values )
276+
277+ x_checked_image_torch = torch .from_numpy (x_checked_image ).permute (0 , 3 , 2 , 1 )
269278
270279 if not opt .skip_save :
271- for x_sample in x_samples_ddim :
280+ for x_sample in x_checked_image_torch :
272281 x_sample = 255. * rearrange (x_sample .cpu ().numpy (), 'c h w -> h w c' )
273282 Image .fromarray (x_sample .astype (np .uint8 )).save (
274283 os .path .join (sample_path , f"{ base_count :05} .png" ))
275284 base_count += 1
276285
277286 if not opt .skip_grid :
278- all_samples .append (x_samples_ddim )
287+ all_samples .append (x_checked_image_torch )
279288
280289 if not opt .skip_grid :
281290 # additionally, save as grid
@@ -288,12 +297,6 @@ def main():
288297 Image .fromarray (grid .astype (np .uint8 )).save (os .path .join (outpath , f'grid-{ grid_count :04} .png' ))
289298 grid_count += 1
290299
291- image = x_samples_ddim .cpu ().permute (0 , 2 , 3 , 1 ).numpy ()
292-
293- # run safety checker
294- safety_checker_input = pipe .feature_extractor (numpy_to_pil (image ), return_tensors = "pt" )
295- image , has_nsfw_concept = pipe .safety_checker (images = image , clip_input = safety_checker_input .pixel_values )
296-
297300 print (f"Your samples are ready and waiting for you here: \n { outpath } \n "
298301 f" \n Enjoy." )
299302
0 commit comments