diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index f97a4d0cd0f4..eed0575c322d 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -742,3 +742,29 @@ accelerate launch train_dreambooth.py \ ## Stable Diffusion XL We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md). + +## Dataset + +We support 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can find a dataset on the [Hugging Face Hub](https://huggingface.co/datasets) or use your own. + +The quickest way to get started with your custom dataset is 🤗 Datasets' [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder). + +We need to create a `metadata.jsonl` file in the directory with our images: + +``` +{"file_name": "01.jpg", "prompt": "prompt 01"} +{"file_name": "02.jpg", "prompt": "prompt 02"} +``` + +If we have a directory with image-text pairs, e.g. `01.jpg` and `01.txt`, then `convert_to_imagefolder.py` can create `metadata.jsonl`. + +```sh +python convert_to_imagefolder.py --path my_dataset/ +``` + +We then pass `--dataset_name` and `--caption_column` to the training scripts. 
"""Create an ImageFolder ``metadata.jsonl`` from a folder of image-text pairs.

Every ``<stem>.txt`` caption file that has a sibling file sharing its stem
(e.g. ``01.txt`` next to ``01.jpg``) becomes one JSON line of the form
``{"file_name": "01.jpg", "<caption_column>": "<caption text>"}``.

Usage::

    python convert_to_imagefolder.py --path my_dataset/

The converted folder is then passed to the training scripts with::

    --dataset_name=my_dataset/
    --caption_column=prompt
"""

import argparse
import json
import pathlib


# Name of the file written into the dataset folder.
METADATA_FILENAME = "metadata.jsonl"


def parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace`` with ``path`` and ``caption_column``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--path",
        type=str,
        required=True,
        help="Path to folder with image-text pairs.",
    )
    parser.add_argument("--caption_column", type=str, default="prompt", help="Name of caption column.")
    return parser.parse_args(argv)


def _pair_captions_with_images(folder):
    """Return sorted ``(caption_path, image_path)`` pairs that share a stem."""
    # Only regular files take part: a sub-directory can neither be read as a
    # caption nor serve as an image.
    entries = sorted(entry for entry in folder.glob("*") if entry.is_file())
    captions = sorted(entry for entry in folder.glob("*.txt") if entry.is_file())
    caption_set = set(captions)

    images = {}
    for entry in entries:
        # Skip captions themselves and the metadata file left by an earlier
        # run, which is not an image.
        if entry in caption_set or entry.name == METADATA_FILENAME:
            continue
        # Entries are sorted, so when several files share a stem the first
        # one wins deterministically.
        images.setdefault(entry.stem, entry)

    return [(caption, images[caption.stem]) for caption in captions if caption.stem in images]


def main(argv=None):
    """Write ``metadata.jsonl`` into the folder given by ``--path``.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv``.

    Returns:
        The ``pathlib.Path`` of the written metadata file.

    Raises:
        RuntimeError: If ``--path`` does not exist or is not a directory.
    """
    args = parse_args(argv)

    folder = pathlib.Path(args.path)
    if not folder.exists():
        raise RuntimeError(f"`--path` '{args.path}' does not exist.")
    if not folder.is_dir():
        raise RuntimeError(f"`--path` '{args.path}' is not a directory.")

    # Collect the pairs before the metadata file is (re)created.
    pairs = _pair_captions_with_images(folder)

    metadata = folder.joinpath(METADATA_FILENAME)
    with metadata.open("w", encoding="utf-8") as f:
        for caption, image in pairs:
            caption_text = caption.read_text(encoding="utf-8")
            json.dump({"file_name": image.name, args.caption_column: caption_text}, f)
            f.write("\n")
    return metadata


if __name__ == "__main__":
    main()