Skip to content

Commit c47807b

Browse files
author
Ubuntu
committed
change int to float for lambda values
1 parent 1824e78 commit c47807b

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/aging_gan/train.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def parse_args() -> argparse.Namespace:
3434
p.add_argument(
3535
"--gen_lr",
3636
type=float,
37-
default=1e-4,
37+
default=2e-4,
3838
help="Initial learning rate for generators.",
3939
)
4040
p.add_argument(
@@ -49,31 +49,31 @@ def parse_args() -> argparse.Namespace:
4949
p.add_argument(
5050
"--train_batch_size",
5151
type=int,
52-
default=4,
52+
default=16,
5353
help="Batch size per device during training.",
5454
)
5555
p.add_argument(
5656
"--eval_batch_size",
5757
type=int,
58-
default=8,
58+
default=32,
5959
help="Batch size per device during evaluation.",
6060
)
6161
p.add_argument(
6262
"--lambda_adv_value",
63-
type=int,
64-
default=2,
63+
type=float,
64+
default=2.0,
6565
help="Weight for adversarial loss",
6666
)
6767
p.add_argument(
6868
"--lambda_cyc_value",
69-
type=int,
70-
default=10,
69+
type=float,
70+
default=4.0,
7171
help="Weight for cyclical loss",
7272
)
7373
p.add_argument(
7474
"--lambda_id_value",
75-
type=int,
76-
default=7,
75+
type=float,
76+
default=0.5,
7777
help="Weight for identity loss",
7878
)
7979
p.add_argument(
@@ -160,7 +160,7 @@ def initialize_optimizers(cfg, G, F, DX, DY):
160160

161161

162162
def initialize_loss_functions(
163-
lambda_adv_value: int = 2, lambda_cyc_value: int = 10, lambda_id_value: int = 7
163+
lambda_adv_value: float = 2.0, lambda_cyc_value: float = 10.0, lambda_id_value: float = 7.0
164164
):
165165
mse = nn.MSELoss()
166166
l1 = nn.L1Loss()

0 commit comments

Comments
 (0)