@@ -418,6 +418,11 @@ def parse_args():
418418 default = 4 ,
419419 help = ("The dimension of the LoRA update matrices." ),
420420 )
421+ parser .add_argument (
422+ "--disable_safety_checker" ,
423+ action = "store_true" ,
424+ help = ("Disable the safety checker in the pipeline. Use this flag to allow full image output without filtering." ),
425+ )
421426
422427 args = parser .parse_args ()
423428 env_local_rank = int (os .environ .get ("LOCAL_RANK" , - 1 ))
@@ -923,6 +928,11 @@ def collate_fn(examples):
923928 variant = args .variant ,
924929 torch_dtype = weight_dtype ,
925930 )
931+
932+ # Disable the safety checker if the flag is present
933+ if args .disable_safety_checker :
934+ pipeline .safety_checker = None
935+
926936 images = log_validation (pipeline , args , accelerator , epoch )
927937
928938 del pipeline
@@ -954,6 +964,10 @@ def collate_fn(examples):
954964 # load attention processors
955965 pipeline .load_lora_weights (args .output_dir )
956966
967+ # Disable the safety checker if the flag is present
968+ if args .disable_safety_checker :
969+ pipeline .safety_checker = None
970+
957971 # run inference
958972 images = log_validation (pipeline , args , accelerator , epoch , is_final_validation = True )
959973
0 commit comments