diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index ed9a6453f038..9db0c62b8186 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -418,6 +418,11 @@ def parse_args(): default=4, help=("The dimension of the LoRA update matrices."), ) + parser.add_argument( + "--disable_safety_checker", + action="store_true", + help=("Disable the safety checker in the pipeline. Use this flag to allow full image output without filtering."), + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -923,6 +928,11 @@ def collate_fn(examples): variant=args.variant, torch_dtype=weight_dtype, ) + + # Disable the safety checker if the flag is present + if args.disable_safety_checker: + pipeline.safety_checker = None + images = log_validation(pipeline, args, accelerator, epoch) del pipeline @@ -954,6 +964,10 @@ def collate_fn(examples): # load attention processors pipeline.load_lora_weights(args.output_dir) + # Disable the safety checker if the flag is present + if args.disable_safety_checker: + pipeline.safety_checker = None + # run inference images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True)