@@ -34,28 +34,28 @@ def parse_args() -> argparse.Namespace:
     p.add_argument(
         "--gen_lr",
         type=float,
-        default=3e-4,
+        default=1e-4,
         help="Initial learning rate for generators.",
     )
     p.add_argument(
         "--disc_lr",
         type=float,
-        default=2e-4,
+        default=1e-4,
         help="Initial learning rate for discriminators.",
     )
     p.add_argument(
-        "--num_train_epochs", type=int, default=25, help="Number of training epochs."
+        "--num_train_epochs", type=int, default=100, help="Number of training epochs."
     )
     p.add_argument(
         "--train_batch_size",
         type=int,
-        default=16,
+        default=4,
         help="Batch size per device during training.",
     )
     p.add_argument(
         "--eval_batch_size",
         type=int,
-        default=32,
+        default=8,
         help="Batch size per device during evaluation.",
     )
     p.add_argument(
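This hunk lowers both learning rates to 1e-4, quadruples the epoch count, and shrinks both batch sizes, a common trade when memory is tight and longer training compensates for smaller steps. Below is a minimal sketch of how these defaults might feed the optimizers; the `betas` value and the one-layer stand-in models are assumptions for illustration, not taken from this diff:

```python
import argparse

import torch
import torch.nn as nn

p = argparse.ArgumentParser()
p.add_argument("--gen_lr", type=float, default=1e-4)
p.add_argument("--disc_lr", type=float, default=1e-4)
p.add_argument("--num_train_epochs", type=int, default=100)
p.add_argument("--train_batch_size", type=int, default=4)
p.add_argument("--eval_batch_size", type=int, default=8)
cfg = p.parse_args([])  # empty argv, so the defaults above are used

G = nn.Conv2d(3, 3, 3, padding=1)   # stand-in for a generator (assumption)
DX = nn.Conv2d(3, 1, 3, padding=1)  # stand-in for a discriminator (assumption)
opt_G = torch.optim.Adam(G.parameters(), lr=cfg.gen_lr, betas=(0.5, 0.999))
opt_DX = torch.optim.Adam(DX.parameters(), lr=cfg.disc_lr, betas=(0.5, 0.999))
```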
@@ -67,7 +67,7 @@ def parse_args() -> argparse.Namespace:
     p.add_argument(
         "--lambda_cyc_value",
         type=int,
-        default=7,
+        default=10,
         help="Weight for cyclical loss",
     )
     p.add_argument(
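Raising `lambda_cyc_value` to 10 matches the weight commonly used in the original CycleGAN setup. For context, here is a sketch of the cycle-consistency term this flag scales, assuming the standard two-direction L1 formulation; `G`, `F`, `x`, and `y` are hypothetical stand-ins:

```python
import torch
import torch.nn as nn

l1 = nn.L1Loss()
lambda_cyc = 10  # the new default above

G = nn.Identity()  # X -> Y generator stand-in
F = nn.Identity()  # Y -> X generator stand-in
x = torch.randn(4, 3, 64, 64)  # batch from domain X
y = torch.randn(4, 3, 64, 64)  # batch from domain Y

# Round-trip each batch through both generators and penalize the
# L1 distance to the original, weighted by lambda_cyc.
cycle_loss = lambda_cyc * (l1(F(G(x)), x) + l1(G(F(y)), y))
```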
@@ -164,7 +164,7 @@ def initialize_optimizers(cfg, G, F, DX, DY):
 
 
 def initialize_loss_functions(
-    lambda_adv_value: int = 2, lambda_cyc_value: int = 10, lambda_id_value: int = 5
+    lambda_adv_value: int = 2, lambda_cyc_value: int = 10, lambda_id_value: int = 7
 ):
     mse = nn.MSELoss()
     l1 = nn.L1Loss()
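The diff shows only this function's signature and its two criteria, so the closures below are an illustrative assumption about how the three weights could be applied, not the repository's actual body:

```python
import torch.nn as nn


def initialize_loss_functions(
    lambda_adv_value: int = 2, lambda_cyc_value: int = 10, lambda_id_value: int = 7
):
    mse = nn.MSELoss()
    l1 = nn.L1Loss()

    # Adversarial loss: least-squares GAN objective on discriminator outputs.
    def adv_loss(pred, target):
        return lambda_adv_value * mse(pred, target)

    # Cycle-consistency loss: L1 between a round-tripped image and the input.
    def cyc_loss(reconstructed, real):
        return lambda_cyc_value * l1(reconstructed, real)

    # Identity loss: L1 between a same-domain mapping and the input.
    def id_loss(same, real):
        return lambda_id_value * l1(same, real)

    return adv_loss, cyc_loss, id_loss
```

A caller would then unpack the three weighted criteria, e.g. `adv, cyc, ident = initialize_loss_functions()`.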
@@ -643,22 +643,23 @@ def main() -> None:
             "best",
         )
         # save the latest checkpoint
-        save_checkpoint(
-            epoch,
-            G,
-            F,
-            DX,
-            DY,
-            opt_G,
-            opt_F,  # generator optimizers
-            opt_DX,
-            opt_DY,  # discriminator optimizers
-            sched_G,
-            sched_F,
-            sched_DX,
-            sched_DY,  # schedulers
-            "latest",
-        )
+        if epoch % 5 == 0:
+            save_checkpoint(
+                epoch,
+                G,
+                F,
+                DX,
+                DY,
+                opt_G,
+                opt_F,  # generator optimizers
+                opt_DX,
+                opt_DY,  # discriminator optimizers
+                sched_G,
+                sched_F,
+                sched_DX,
+                sched_DY,  # schedulers
+                "current",
+            )
 
     # ---------- Test ----------
     if cfg.do_test:
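This hunk replaces the every-epoch "latest" checkpoint with a "current" checkpoint written every fifth epoch, cutting disk I/O at the cost of losing up to four epochs of progress on a crash. A hedged sketch of a `save_checkpoint` compatible with the call above follows; the argument order and the trailing tag come from the diff, while the state-dict layout and the `checkpoints/` path are assumptions:

```python
import os

import torch


def save_checkpoint(epoch, G, F, DX, DY,
                    opt_G, opt_F, opt_DX, opt_DY,
                    sched_G, sched_F, sched_DX, sched_DY,
                    tag):
    # Bundle every stateful object so training can resume exactly.
    state = {
        "epoch": epoch,
        "G": G.state_dict(), "F": F.state_dict(),
        "DX": DX.state_dict(), "DY": DY.state_dict(),
        "opt_G": opt_G.state_dict(), "opt_F": opt_F.state_dict(),
        "opt_DX": opt_DX.state_dict(), "opt_DY": opt_DY.state_dict(),
        "sched_G": sched_G.state_dict(), "sched_F": sched_F.state_dict(),
        "sched_DX": sched_DX.state_dict(), "sched_DY": sched_DY.state_dict(),
    }
    os.makedirs("checkpoints", exist_ok=True)  # hypothetical output dir
    # One file per tag ("best" / "current"); the "current" file is
    # overwritten every fifth epoch by the loop shown above.
    torch.save(state, f"checkpoints/{tag}.pt")
```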