Skip to content

Commit 3fd7a77

Browse files
committed
add typehit for updater and evaluator, test=tts
1 parent 9c7f076 commit 3fd7a77

File tree

4 files changed

+48
-34
lines changed

4 files changed

+48
-34
lines changed

paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15+
from pathlib import Path
1516

1617
from paddle import distributed as dist
18+
from paddle.io import DataLoader
19+
from paddle.nn import Layer
20+
from paddle.optimizer import Optimizer
1721

1822
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
1923
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
@@ -28,13 +32,13 @@
2832

2933
class FastSpeech2Updater(StandardUpdater):
3034
def __init__(self,
31-
model,
32-
optimizer,
33-
dataloader,
35+
model: Layer,
36+
optimizer: Optimizer,
37+
dataloader: DataLoader,
3438
init_state=None,
35-
use_masking=False,
36-
use_weighted_masking=False,
37-
output_dir=None):
39+
use_masking: bool=False,
40+
use_weighted_masking: bool=False,
41+
output_dir: Path=None):
3842
super().__init__(model, optimizer, dataloader, init_state=None)
3943

4044
self.criterion = FastSpeech2Loss(
@@ -104,11 +108,11 @@ def update_core(self, batch):
104108

105109
class FastSpeech2Evaluator(StandardEvaluator):
106110
def __init__(self,
107-
model,
108-
dataloader,
109-
use_masking=False,
110-
use_weighted_masking=False,
111-
output_dir=None):
111+
model: Layer,
112+
dataloader: DataLoader,
113+
use_masking: bool=False,
114+
use_weighted_masking: bool=False,
115+
output_dir: Path=None):
112116
super().__init__(model, dataloader)
113117

114118
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())

paddlespeech/t2s/models/new_tacotron2/tacotron2_updater.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
import logging
1515
from pathlib import Path
16-
from typing import Dict
1716

1817
from paddle import distributed as dist
1918
from paddle.io import DataLoader
@@ -34,8 +33,8 @@
3433

3534
class Tacotron2Updater(StandardUpdater):
3635
def __init__(self,
37-
model: Dict[str, Layer],
38-
optimizer: Dict[str, Optimizer],
36+
model: Layer,
37+
optimizer: Optimizer,
3938
dataloader: DataLoader,
4039
init_state=None,
4140
use_masking: bool=True,
@@ -126,8 +125,8 @@ def update_core(self, batch):
126125

127126
class Tacotron2Evaluator(StandardEvaluator):
128127
def __init__(self,
129-
model,
130-
dataloader,
128+
model: Layer,
129+
dataloader: DataLoader,
131130
use_masking: bool=True,
132131
use_weighted_masking: bool=False,
133132
bce_pos_weight: float=5.0,

paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15+
from pathlib import Path
1516

1617
import paddle
1718
from paddle import distributed as dist
1819
from paddle.fluid.layers import huber_loss
20+
from paddle.io import DataLoader
1921
from paddle.nn import functional as F
22+
from paddle.nn import Layer
23+
from paddle.optimizer import Optimizer
2024

2125
from paddlespeech.t2s.modules.losses import masked_l1_loss
2226
from paddlespeech.t2s.modules.losses import ssim
@@ -33,11 +37,11 @@
3337

3438
class SpeedySpeechUpdater(StandardUpdater):
3539
def __init__(self,
36-
model,
37-
optimizer,
38-
dataloader,
40+
model: Layer,
41+
optimizer: Optimizer,
42+
dataloader: DataLoader,
3943
init_state=None,
40-
output_dir=None):
44+
output_dir: Path=None):
4145
super().__init__(model, optimizer, dataloader, init_state=None)
4246

4347
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
@@ -103,7 +107,10 @@ def update_core(self, batch):
103107

104108

105109
class SpeedySpeechEvaluator(StandardEvaluator):
106-
def __init__(self, model, dataloader, output_dir=None):
110+
def __init__(self,
111+
model: Layer,
112+
dataloader: DataLoader,
113+
output_dir: Path=None):
107114
super().__init__(model, dataloader)
108115

109116
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())

paddlespeech/t2s/models/transformer_tts/transformer_tts_updater.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15+
from pathlib import Path
1516
from typing import Sequence
1617

1718
import paddle
1819
from paddle import distributed as dist
20+
from paddle.io import DataLoader
21+
from paddle.nn import Layer
22+
from paddle.optimizer import Optimizer
1923

2024
from paddlespeech.t2s.modules.losses import GuidedMultiHeadAttentionLoss
2125
from paddlespeech.t2s.modules.losses import Tacotron2Loss as TransformerTTSLoss
@@ -32,14 +36,14 @@
3236
class TransformerTTSUpdater(StandardUpdater):
3337
def __init__(
3438
self,
35-
model,
36-
optimizer,
37-
dataloader,
39+
model: Layer,
40+
optimizer: Optimizer,
41+
dataloader: DataLoader,
3842
init_state=None,
39-
use_masking=False,
40-
use_weighted_masking=False,
41-
output_dir=None,
42-
bce_pos_weight=5.0,
43+
use_masking: bool=False,
44+
use_weighted_masking: bool=False,
45+
output_dir: Path=None,
46+
bce_pos_weight: float=5.0,
4347
loss_type: str="L1",
4448
use_guided_attn_loss: bool=True,
4549
modules_applied_guided_attn: Sequence[str]=("encoder-decoder"),
@@ -185,13 +189,13 @@ def update_core(self, batch):
185189
class TransformerTTSEvaluator(StandardEvaluator):
186190
def __init__(
187191
self,
188-
model,
189-
dataloader,
192+
model: Layer,
193+
dataloader: DataLoader,
190194
init_state=None,
191-
use_masking=False,
192-
use_weighted_masking=False,
193-
output_dir=None,
194-
bce_pos_weight=5.0,
195+
use_masking: bool=False,
196+
use_weighted_masking: bool=False,
197+
output_dir: Path=None,
198+
bce_pos_weight: float=5.0,
195199
loss_type: str="L1",
196200
use_guided_attn_loss: bool=True,
197201
modules_applied_guided_attn: Sequence[str]=("encoder-decoder"),

0 commit comments

Comments
 (0)