Commit 9b39abc

[training] Convert to ImageFolder script
1 parent 4f3ec53 commit 9b39abc

2 files changed: 63 additions, 0 deletions

examples/dreambooth/README.md

Lines changed: 26 additions & 0 deletions
@@ -742,3 +742,29 @@ accelerate launch train_dreambooth.py \
## Stable Diffusion XL

We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).

## Dataset

We support 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can find a dataset on the [Hugging Face Hub](https://huggingface.co/datasets) or use your own.

The quickest way to get started with a custom dataset is 🤗 Datasets' [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder).

Create a `metadata.jsonl` file in the directory that contains the images, with one JSON object per line:

```
{"file_name": "01.jpg", "prompt": "prompt 01"}
{"file_name": "02.jpg", "prompt": "prompt 02"}
```

If the directory holds image-text pairs that share a file stem, e.g. `01.jpg` and `01.txt`, then `convert_to_imagefolder.py` can create `metadata.jsonl` from them:

```sh
python convert_to_imagefolder.py --path my_dataset/
```

Pass `--dataset_name` and `--caption_column` to the training scripts. `--caption_column` should name the caption key written to `metadata.jsonl` (`prompt` is the converter's default):

```
--dataset_name=my_dataset/
--caption_column=prompt
```
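
Before launching training, it can help to confirm that `metadata.jsonl` agrees with the folder and with the `--caption_column` value. The following is a minimal sketch using only the standard library; `check_metadata` is a hypothetical helper, not part of this commit or of 🤗 Datasets, and it only checks that files exist, not that they are valid images.

```python
import json
import pathlib


def check_metadata(folder, caption_column="prompt"):
    """Return a list of problems found in <folder>/metadata.jsonl.

    Hypothetical helper for illustration; an empty list means no problem was found.
    """
    folder = pathlib.Path(folder)
    metadata = folder / "metadata.jsonl"
    if not metadata.is_file():
        return [f"{metadata} is missing"]

    problems = []
    with metadata.open(encoding="utf-8") as f:
        for number, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            try:
                record = json.loads(line)
            except json.JSONDecodeError as error:
                problems.append(f"line {number}: invalid JSON ({error.msg})")
                continue
            if not isinstance(record, dict):
                problems.append(f"line {number}: expected a JSON object")
                continue
            file_name = record.get("file_name")
            if not file_name:
                problems.append(f"line {number}: no 'file_name' key")
            elif not (folder / file_name).is_file():
                problems.append(f"line {number}: image '{file_name}' not found")
            if caption_column not in record:
                problems.append(f"line {number}: no '{caption_column}' key")
    return problems
```

Calling `check_metadata("my_dataset/", caption_column="prompt")` returns an empty list when every record points at an existing file and carries the caption key.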
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
import argparse
import json
import pathlib

parser = argparse.ArgumentParser()
parser.add_argument(
    "--path",
    type=str,
    required=True,
    help="Path to folder with image-text pairs.",
)
parser.add_argument(
    "--caption_column", type=str, default="prompt", help="Name of caption column."
)
args = parser.parse_args()

path = pathlib.Path(args.path)
if not path.exists():
    raise RuntimeError(f"`--path` '{args.path}' does not exist.")

all_files = list(path.glob("*"))
captions = list(path.glob("*.txt"))
images = set(all_files) - set(captions)
images = {image.stem: image for image in images}
caption_image = {
    caption: images.get(caption.stem)
    for caption in captions
    if images.get(caption.stem)
}

metadata = path.joinpath("metadata.jsonl")

with metadata.open("w", encoding="utf-8") as f:
    for caption, image in caption_image.items():
        caption_text = caption.read_text(encoding="utf-8")
        json.dump({"file_name": image.name, args.caption_column: caption_text}, f)
        f.write("\n")
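
The script pairs each `*.txt` caption with the non-caption file that has the same stem. The sketch below wraps that idea in an importable function, with three changes that are assumptions and not part of the commit: only files with a few common image extensions count as images (so subdirectories or a stale `metadata.jsonl` are never paired), captions are stripped of surrounding whitespace, and records are written in sorted order. `build_metadata` and `IMAGE_SUFFIXES` are hypothetical names.

```python
import json
import pathlib

# Assumed set of image extensions; adjust to match your data.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


def build_metadata(folder, caption_column="prompt"):
    """Write <folder>/metadata.jsonl from image-text pairs and return the records."""
    folder = pathlib.Path(folder)
    if not folder.is_dir():
        raise RuntimeError(f"'{folder}' is not a directory.")

    # Map file stem -> image path, ignoring directories and non-image files.
    images = {
        p.stem: p
        for p in sorted(folder.iterdir())
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    }

    records = []
    for caption in sorted(folder.glob("*.txt")):
        image = images.get(caption.stem)
        if image is None:
            continue  # a caption without a matching image is skipped
        text = caption.read_text(encoding="utf-8").strip()
        records.append({"file_name": image.name, caption_column: text})

    # One JSON object per line, as ImageFolder's metadata.jsonl expects.
    with (folder / "metadata.jsonl").open("w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")
    return records
```

If two images share a stem (say `01.jpg` and `01.png`), the one that sorts last wins in this sketch; the committed script has the same one-image-per-stem limitation, but which file it keeps is not deterministic because it iterates over a set.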
