Skip to content

Commit 555793f

Browse files
Fix: Add option to disable safety_checker to prevent black images
1 parent 811560b commit 555793f

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,11 @@ def parse_args():
418418
default=4,
419419
help=("The dimension of the LoRA update matrices."),
420420
)
421+
parser.add_argument(
422+
"--disable_safety_checker",
423+
action="store_true",
424+
help=("Disable the safety checker in the pipeline. Use this flag to allow full image output without filtering."),
425+
)
421426

422427
args = parser.parse_args()
423428
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -923,6 +928,11 @@ def collate_fn(examples):
923928
variant=args.variant,
924929
torch_dtype=weight_dtype,
925930
)
931+
932+
# Disable the safety checker if the flag is present
933+
if args.disable_safety_checker:
934+
pipeline.safety_checker = None
935+
926936
images = log_validation(pipeline, args, accelerator, epoch)
927937

928938
del pipeline
@@ -954,6 +964,10 @@ def collate_fn(examples):
954964
# load attention processors
955965
pipeline.load_lora_weights(args.output_dir)
956966

967+
# Disable the safety checker if the flag is present
968+
if args.disable_safety_checker:
969+
pipeline.safety_checker = None
970+
957971
# run inference
958972
images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True)
959973

0 commit comments

Comments
 (0)