@@ -60,25 +60,22 @@ def main() -> None:
6060 # Model hyper parameters
6161 model_parser = sub_command .add_parser ("model" )
6262
63- model_parser .add_argument ("--steps" , type = int , default = 1024 )
64- model_parser .add_argument ("--beta-1" , type = float , default = 1e-4 )
65- model_parser .add_argument ("--beta-t" , type = float , default = 2e-2 )
63+ model_parser .add_argument ("--steps" , type = int , default = 4096 )
64+ model_parser .add_argument ("--beta-1" , type = float , default = 2.5e-5 )
65+ model_parser .add_argument ("--beta-t" , type = float , default = 5e-3 )
6666 model_parser .add_argument ("--channels" , type = int , default = 2 )
6767 model_parser .add_argument (
6868 "--unet-channels" ,
6969 type = _channels ,
7070 default = [
71- (8 , 16 ),
72- (16 , 24 ),
73- (24 , 32 ),
74- (32 , 40 ),
75- (40 , 48 ),
76- (48 , 56 ),
77- (56 , 64 ),
71+ (16 , 32 ),
72+ (32 , 48 ),
73+ (48 , 64 ),
74+ (64 , 80 ),
7875 ],
7976 )
80- model_parser .add_argument ("--time-size" , type = int , default = 32 )
81- model_parser .add_argument ("--norm-groups" , type = int , default = 8 )
77+ model_parser .add_argument ("--time-size" , type = int , default = 48 )
78+ model_parser .add_argument ("--norm-groups" , type = int , default = 16 )
8279 model_parser .add_argument ("--cuda" , action = "store_true" )
8380
8481 # Sub command run {train, generate}
@@ -92,7 +89,7 @@ def main() -> None:
9289 train_parser .add_argument ("run_name" , type = str )
9390
9491 train_parser .add_argument ("-i" , "--input-dataset" , type = str , required = True )
95- train_parser .add_argument ("--batch-size" , type = int , default = 8 )
92+ train_parser .add_argument ("--batch-size" , type = int , default = 6 )
9693 train_parser .add_argument ("--step-batch-size" , type = int , default = 1 )
9794 train_parser .add_argument ("--epochs" , type = int , default = 1000 )
9895 train_parser .add_argument ("--learning-rate" , type = float , default = 2e-4 )
0 commit comments