Skip to content

Commit 73056fb

Browse files
committed
update CLI default arguments with working ones
1 parent 6ace75c commit 73056fb

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

music_diffusion/__main__.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,25 +60,22 @@ def main() -> None:
6060
# Model hyper parameters
6161
model_parser = sub_command.add_parser("model")
6262

63-
model_parser.add_argument("--steps", type=int, default=1024)
64-
model_parser.add_argument("--beta-1", type=float, default=1e-4)
65-
model_parser.add_argument("--beta-t", type=float, default=2e-2)
63+
model_parser.add_argument("--steps", type=int, default=4096)
64+
model_parser.add_argument("--beta-1", type=float, default=2.5e-5)
65+
model_parser.add_argument("--beta-t", type=float, default=5e-3)
6666
model_parser.add_argument("--channels", type=int, default=2)
6767
model_parser.add_argument(
6868
"--unet-channels",
6969
type=_channels,
7070
default=[
71-
(8, 16),
72-
(16, 24),
73-
(24, 32),
74-
(32, 40),
75-
(40, 48),
76-
(48, 56),
77-
(56, 64),
71+
(16, 32),
72+
(32, 48),
73+
(48, 64),
74+
(64, 80),
7875
],
7976
)
80-
model_parser.add_argument("--time-size", type=int, default=32)
81-
model_parser.add_argument("--norm-groups", type=int, default=8)
77+
model_parser.add_argument("--time-size", type=int, default=48)
78+
model_parser.add_argument("--norm-groups", type=int, default=16)
8279
model_parser.add_argument("--cuda", action="store_true")
8380

8481
# Sub command run {train, generate}
@@ -92,7 +89,7 @@ def main() -> None:
9289
train_parser.add_argument("run_name", type=str)
9390

9491
train_parser.add_argument("-i", "--input-dataset", type=str, required=True)
95-
train_parser.add_argument("--batch-size", type=int, default=8)
92+
train_parser.add_argument("--batch-size", type=int, default=6)
9693
train_parser.add_argument("--step-batch-size", type=int, default=1)
9794
train_parser.add_argument("--epochs", type=int, default=1000)
9895
train_parser.add_argument("--learning-rate", type=float, default=2e-4)

0 commit comments

Comments
 (0)