26 changes: 26 additions & 0 deletions examples/dreambooth/README.md
@@ -742,3 +742,29 @@ accelerate launch train_dreambooth.py \
## Stable Diffusion XL

We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).

## Dataset

We support 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can find a dataset on the [Hugging Face Hub](https://huggingface.co/datasets) or use your own.

The quickest way to get started with your custom dataset is 🤗 Datasets' [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder).

We need to create a file `metadata.jsonl` in the directory with our images:

```
{"file_name": "01.jpg", "prompt": "prompt 01"}
{"file_name": "02.jpg", "prompt": "prompt 02"}
```
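
Before training, it can help to confirm that every line of `metadata.jsonl` parses and points at an image that actually exists. The following is a minimal sketch using only the standard library. The name `check_metadata` is hypothetical (it is not part of this repository), and the sketch assumes the caption column is called `prompt`, as in the example above.

```python
import json
import pathlib


def check_metadata(folder, caption_column="prompt"):
    """Return a list of problems found in <folder>/metadata.jsonl (empty if none).

    Hypothetical helper for illustration; not part of the training scripts.
    """
    folder = pathlib.Path(folder)
    problems = []
    with (folder / "metadata.jsonl").open(encoding="utf-8") as f:
        for number, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            try:
                row = json.loads(line)
            except json.JSONDecodeError as err:
                problems.append(f"line {number}: invalid JSON ({err})")
                continue
            if not isinstance(row, dict):
                problems.append(f"line {number}: expected a JSON object")
                continue
            # Each row must name an image that sits next to metadata.jsonl.
            file_name = row.get("file_name")
            if not file_name or not (folder / file_name).is_file():
                problems.append(f"line {number}: missing image {file_name!r}")
            # Each row must also carry the caption column used for training.
            if caption_column not in row:
                problems.append(f"line {number}: no {caption_column!r} column")
    return problems
```

Calling `check_metadata("my_dataset/")` returns an empty list when every row is usable.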

If we have a directory of image-text pairs, e.g. `01.jpg` and `01.txt`, then `convert_to_imagefolder.py` can create `metadata.jsonl` for us.

```sh
python convert_to_imagefolder.py --path my_dataset/
```
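
The idea behind the conversion can be sketched as a small function: each `*.txt` caption is matched to the non-text file with the same stem, and files without a partner are skipped. This is a simplified illustration rather than the script itself, and `build_rows` is a hypothetical name.

```python
import pathlib


def build_rows(folder, caption_column="prompt"):
    """Pair NAME.txt captions with NAME.<image ext> files in `folder`.

    Illustrative sketch only; `build_rows` is a hypothetical helper.
    """
    folder = pathlib.Path(folder)
    # Index every non-caption file by its stem, e.g. "01" -> 01.jpg.
    images = {p.stem: p for p in folder.iterdir() if p.is_file() and p.suffix != ".txt"}
    rows = []
    for caption in sorted(folder.glob("*.txt")):
        image = images.get(caption.stem)
        if image is None:
            continue  # caption without a matching image
        text = caption.read_text(encoding="utf-8")
        rows.append({"file_name": image.name, caption_column: text})
    return rows
```

Each returned dictionary corresponds to one line of `metadata.jsonl`.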

We then pass `--dataset_name` (the dataset directory) and `--caption_column` (the caption field in `metadata.jsonl`) to the training scripts:

```
--dataset_name=my_dataset/
--caption_column=prompt
```
32 changes: 32 additions & 0 deletions examples/dreambooth/convert_to_imagefolder.py
@@ -0,0 +1,32 @@
import argparse
import json
import pathlib


parser = argparse.ArgumentParser()
parser.add_argument(
    "--path",
    type=str,
    required=True,
    help="Path to folder with image-text pairs.",
)
parser.add_argument("--caption_column", type=str, default="prompt", help="Name of caption column.")
args = parser.parse_args()

path = pathlib.Path(args.path)
if not path.exists():
    raise RuntimeError(f"`--path` '{args.path}' does not exist.")

# Pair each caption (*.txt) with the image that shares its file stem.
captions = sorted(path.glob("*.txt"))
# Skip directories and a `metadata.jsonl` left over from a previous run.
images = {p.stem: p for p in path.iterdir() if p.is_file() and p.suffix not in {".txt", ".jsonl"}}
caption_image = {caption: images[caption.stem] for caption in captions if caption.stem in images}

metadata = path.joinpath("metadata.jsonl")

with metadata.open("w", encoding="utf-8") as f:
    for caption, image in caption_image.items():
        caption_text = caption.read_text(encoding="utf-8")
        json.dump({"file_name": image.name, args.caption_column: caption_text}, f)
        f.write("\n")