diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index f73269a48967..f4ee1231b402 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -308,7 +308,7 @@ def parse_args(input_args=None): "--instance_prompt", type=str, default=None, - required=True, + required=False, help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", ) parser.add_argument( @@ -714,7 +714,7 @@ def __init__( # we load the training data using load_dataset if args.dataset_name is not None: try: - from datasets import load_dataset + from datasets import load_dataset, Image except ImportError: raise ImportError( "You are trying to load your data using the datasets library. If you wish to train using custom " @@ -742,6 +742,8 @@ def __init__( raise ValueError( f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" ) + # force casting for an image column to decode when this isn't the default + dataset["train"] = dataset["train"].cast_column(image_column, Image(decode=True)) instance_images = dataset["train"][image_column] if args.caption_column is None: