Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/dreambooth/train_dreambooth_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def parse_args(input_args=None):
"--instance_prompt",
type=str,
default=None,
required=True,
required=False,
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
Expand Down Expand Up @@ -714,7 +714,7 @@ def __init__(
# we load the training data using load_dataset
if args.dataset_name is not None:
try:
from datasets import load_dataset
from datasets import load_dataset, Image
except ImportError:
raise ImportError(
"You are trying to load your data using the datasets library. If you wish to train using custom "
Expand Down Expand Up @@ -742,6 +742,7 @@ def __init__(
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
dataset["train"] = dataset["train"].cast_column(image_column, Image(decode=True))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed?

Copy link
Contributor Author

@davidberenstein1957 davidberenstein1957 Dec 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sayakpaul This is a fix that forces the image column to be cast to `Image(decode=True)`, ensuring the images are decoded even for datasets whose features do not decode automatically by default.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, let's have that as a comment too. Thanks for explaining!

instance_images = dataset["train"][image_column]

if args.caption_column is None:
Expand All @@ -768,7 +769,6 @@ def __init__(

instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
self.custom_instance_prompts = None

self.instance_images = []
for img in instance_images:
self.instance_images.extend(itertools.repeat(img, repeats))
Expand Down
Loading