Skip to content

Commit fca393a

Browse files
committed
👊 Change adam->radam and use ExponentialDecay Learning rate.
1 parent 09026c7 commit fca393a

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,19 @@ is_shuffle: true # shuffle dataset after each epoch.
7676
# OPTIMIZER & SCHEDULER SETTING #
7777
###########################################################
7878
generator_optimizer_params:
79-
lr_fn: "PiecewiseConstantDecay"
79+
lr_fn: "ExponentialDecay"
8080
lr_params:
81-
boundaries: [100000] # = discriminator_train_start_steps.
82-
values: [0.0005, 0.0001] # learning rate each interval.
81+
initial_learning_rate: 0.0005
82+
decay_steps: 200000
83+
decay_rate: 0.5
8384

8485

8586
discriminator_optimizer_params:
86-
lr_fn: "PiecewiseConstantDecay"
87+
lr_fn: "ExponentialDecay"
8788
lr_params:
88-
boundaries: [0] # after resume and start training discriminator, global steps is 100k, but local discriminator step is 0
89-
values: [0.0001, 0.0001] # learning rate each interval.
90-
89+
initial_learning_rate: 0.0005
90+
decay_steps: 200000
91+
decay_rate: 0.5
9192

9293
###########################################################
9394
# INTERVAL SETTING #

examples/parallel_wavegan/train_parallel_wavegan.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
from tensorflow_tts.losses import TFMultiResolutionSTFT
5151
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy
5252

53+
from tensorflow_addons.optimizers import RectifiedAdam
54+
5355

5456
class ParallelWaveganTrainer(GanBasedTrainer):
5557
"""ParallelWaveGAN Trainer class based on GanBasedTrainer."""
@@ -439,10 +441,10 @@ def main():
439441
config["discriminator_optimizer_params"]["lr_fn"],
440442
)(**config["discriminator_optimizer_params"]["lr_params"])
441443

442-
gen_optimizer = tf.keras.optimizers.Adam(
444+
gen_optimizer = RectifiedAdam(
443445
learning_rate=generator_lr_fn, amsgrad=False
444446
)
445-
dis_optimizer = tf.keras.optimizers.Adam(
447+
dis_optimizer = RectifiedAdam(
446448
learning_rate=discriminator_lr_fn, amsgrad=False
447449
)
448450

0 commit comments

Comments
 (0)