Skip to content

Commit 074a7cc

Browse files
bghirabghirakashif
authored
SD3: update default training timestep / loss weighting distribution to logit_normal (#8592)
Co-authored-by: bghira <[email protected]> Co-authored-by: Kashif Rasul <[email protected]>
1 parent 6bfd13f commit 074a7cc

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def parse_args(input_args=None):
473473
),
474474
)
475475
parser.add_argument(
476-
"--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"]
476+
"--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"]
477477
)
478478
parser.add_argument("--logit_mean", type=float, default=0.0)
479479
parser.add_argument("--logit_std", type=float, default=1.0)

examples/dreambooth/train_dreambooth_sd3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ def parse_args(input_args=None):
471471
),
472472
)
473473
parser.add_argument(
474-
"--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"]
474+
"--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"]
475475
)
476476
parser.add_argument("--logit_mean", type=float, default=0.0)
477477
parser.add_argument("--logit_std", type=float, default=1.0)

0 commit comments

Comments
 (0)