Skip to content

Commit 2c1cbb3

Browse files
fix required instance_column and image_column cast
1 parent c9e4fab commit 2c1cbb3

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def parse_args(input_args=None):
308308
"--instance_prompt",
309309
type=str,
310310
default=None,
311-
required=True,
311+
required=False,
312312
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
313313
)
314314
parser.add_argument(
@@ -714,7 +714,7 @@ def __init__(
714714
# we load the training data using load_dataset
715715
if args.dataset_name is not None:
716716
try:
717-
from datasets import load_dataset
717+
from datasets import load_dataset, Image
718718
except ImportError:
719719
raise ImportError(
720720
"You are trying to load your data using the datasets library. If you wish to train using custom "
@@ -742,6 +742,8 @@ def __init__(
742742
raise ValueError(
743743
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
744744
)
745+
746+
dataset["train"] = dataset["train"].cast_column(image_column, Image(decode=True))
745747
instance_images = dataset["train"][image_column]
746748

747749
if args.caption_column is None:
@@ -768,7 +770,7 @@ def __init__(
768770

769771
instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
770772
self.custom_instance_prompts = None
771-
773+
772774
self.instance_images = []
773775
for img in instance_images:
774776
self.instance_images.extend(itertools.repeat(img, repeats))
@@ -784,6 +786,7 @@ def __init__(
784786
]
785787
)
786788
for image in self.instance_images:
789+
787790
image = exif_transpose(image)
788791
if not image.mode == "RGB":
789792
image = image.convert("RGB")

0 commit comments

Comments
 (0)