Commit a1d55e1

Change the default weighting_scheme in the SD3 scripts (#8639)
* change to logit_normal as the weighting scheme
* sensible default note
1 parent e5564d4 · commit a1d55e1

3 files changed (+16 −6 lines)


examples/dreambooth/README_sd3.md

Lines changed: 8 additions & 4 deletions
````diff
@@ -11,6 +11,8 @@ The `train_dreambooth_sd3.py` script shows how to implement the training procedu
 huggingface-cli login
 ```
 
+This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
+
 ## Running locally with PyTorch
 
 ### Installing the dependencies
@@ -52,8 +54,6 @@ write_basic_config()
 ```
 
 When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
-Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
-
 
 ### Dog toy example
 
@@ -72,8 +72,6 @@ snapshot_download(
 )
 ```
 
-This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
-
 Now, we can launch training using:
 
 ```bash
@@ -116,6 +114,8 @@ To better track our training experiments, we're using the following flags in the
 
 [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
 
+Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
+
 To perform DreamBooth with LoRA, run:
 
 ```bash
@@ -142,3 +142,7 @@ accelerate launch train_dreambooth_lora_sd3.py \
   --seed="0" \
   --push_to_hub
 ```
+
+## Other notes
+
+We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
````
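For context on the new default: the "logit_normal" scheme samples training timesteps from a logit-normal density, as described in the SD3 paper. Below is a minimal, self-contained sketch of that idea (the helper name is hypothetical, and the `logit_std=1.0` default is an assumption; only `--logit_mean`'s `0.0` default is visible in this diff):

```python
import torch

# Hypothetical illustration, not code from this commit: draw u ~ N(logit_mean, logit_std)
# and squash it through a sigmoid, so sampled timesteps in (0, 1) concentrate around
# the middle of the noise schedule rather than treating all noise levels alike.
def sample_logit_normal(batch_size: int, logit_mean: float = 0.0, logit_std: float = 1.0) -> torch.Tensor:
    u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,))
    return torch.sigmoid(u)

print(sample_logit_normal(4))  # e.g. values like 0.62, 0.45, 0.51, 0.38, clustered near 0.5
```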

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -477,7 +477,10 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+        "--weighting_scheme",
+        type=str,
+        default="logit_normal",
+        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
     )
     parser.add_argument(
         "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
```

examples/dreambooth/train_dreambooth_sd3.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -472,7 +472,10 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+        "--weighting_scheme",
+        type=str,
+        default="logit_normal",
+        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
     )
     parser.add_argument(
         "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
```
