Skip to content

Commit 53497a5

Browse files
committed
🚀 add trainable option for fine tuning
1 parent 4eeec43 commit 53497a5

File tree

7 files changed

+56
-9
lines changed

7 files changed

+56
-9
lines changed

tensorflow_asr/models/conformer.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -378,6 +378,7 @@ def __init__(self,
378378
encoder_depth_multiplier: int = 1,
379379
encoder_fc_factor: float = 0.5,
380380
encoder_dropout: float = 0,
381+
encoder_trainable: bool = True,
381382
prediction_embed_dim: int = 512,
382383
prediction_embed_dropout: int = 0,
383384
prediction_num_rnns: int = 1,
@@ -386,13 +387,15 @@ def __init__(self,
386387
prediction_rnn_implementation: int = 2,
387388
prediction_layer_norm: bool = True,
388389
prediction_projection_units: int = 0,
390+
prediction_trainable: bool = True,
389391
joint_dim: int = 1024,
390392
joint_activation: str = "tanh",
391393
prejoint_linear: bool = True,
392394
joint_mode: str = "add",
395+
joint_trainable: bool = True,
393396
kernel_regularizer=L2,
394397
bias_regularizer=L2,
395-
name: str = "conformer_transducer",
398+
name: str = "conformer",
396399
**kwargs):
397400
super(Conformer, self).__init__(
398401
encoder=ConformerEncoder(
@@ -408,7 +411,9 @@ def __init__(self,
408411
fc_factor=encoder_fc_factor,
409412
dropout=encoder_dropout,
410413
kernel_regularizer=kernel_regularizer,
411-
bias_regularizer=bias_regularizer
414+
bias_regularizer=bias_regularizer,
415+
trainable=encoder_trainable,
416+
name=f"{name}_encoder"
412417
),
413418
vocabulary_size=vocabulary_size,
414419
embed_dim=prediction_embed_dim,
@@ -419,13 +424,16 @@ def __init__(self,
419424
rnn_implementation=prediction_rnn_implementation,
420425
layer_norm=prediction_layer_norm,
421426
projection_units=prediction_projection_units,
427+
prediction_trainable=prediction_trainable,
422428
joint_dim=joint_dim,
423429
joint_activation=joint_activation,
424430
prejoint_linear=prejoint_linear,
425431
joint_mode=joint_mode,
432+
joint_trainable=joint_trainable,
426433
kernel_regularizer=kernel_regularizer,
427434
bias_regularizer=bias_regularizer,
428-
name=name, **kwargs
435+
name=name,
436+
**kwargs
429437
)
430438
self.dmodel = encoder_dmodel
431439
self.time_reduction_factor = self.encoder.conv_subsampling.time_reduction_factor

tensorflow_asr/models/contextnet.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def __init__(self,
197197
vocabulary_size: int,
198198
encoder_blocks: List[dict],
199199
encoder_alpha: float = 0.5,
200+
encoder_trainable: bool = True,
200201
prediction_embed_dim: int = 512,
201202
prediction_embed_dropout: int = 0,
202203
prediction_num_rnns: int = 1,
@@ -205,10 +206,12 @@ def __init__(self,
205206
prediction_rnn_implementation: int = 2,
206207
prediction_layer_norm: bool = True,
207208
prediction_projection_units: int = 0,
209+
prediction_trainable: bool = True,
208210
joint_dim: int = 1024,
209211
joint_activation: str = "tanh",
210212
prejoint_linear: bool = True,
211213
joint_mode: str = "add",
214+
joint_trainable: bool = True,
212215
kernel_regularizer=L2,
213216
bias_regularizer=L2,
214217
name: str = "contextnet",
@@ -219,6 +222,7 @@ def __init__(self,
219222
alpha=encoder_alpha,
220223
kernel_regularizer=kernel_regularizer,
221224
bias_regularizer=bias_regularizer,
225+
trainable=encoder_trainable,
222226
name=f"{name}_encoder"
223227
),
224228
vocabulary_size=vocabulary_size,
@@ -229,14 +233,17 @@ def __init__(self,
229233
rnn_type=prediction_rnn_type,
230234
rnn_implementation=prediction_rnn_implementation,
231235
layer_norm=prediction_layer_norm,
236+
prediction_trainable=prediction_trainable,
232237
projection_units=prediction_projection_units,
233238
joint_dim=joint_dim,
234239
joint_activation=joint_activation,
235240
prejoint_linear=prejoint_linear,
236241
joint_mode=joint_mode,
242+
joint_trainable=joint_trainable,
237243
kernel_regularizer=kernel_regularizer,
238244
bias_regularizer=bias_regularizer,
239-
name=name, **kwargs
245+
name=name,
246+
**kwargs
240247
)
241248
self.dmodel = self.encoder.blocks[-1].dmodel
242249
self.time_reduction_factor = 1

tensorflow_asr/models/keras/conformer.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(self,
3030
encoder_depth_multiplier: int = 1,
3131
encoder_fc_factor: float = 0.5,
3232
encoder_dropout: float = 0,
33+
encoder_trainable: bool = True,
3334
prediction_embed_dim: int = 512,
3435
prediction_embed_dropout: int = 0,
3536
prediction_num_rnns: int = 1,
@@ -38,13 +39,15 @@ def __init__(self,
3839
prediction_rnn_implementation: int = 2,
3940
prediction_layer_norm: bool = True,
4041
prediction_projection_units: int = 0,
42+
prediction_trainable: bool = True,
4143
joint_dim: int = 1024,
4244
joint_activation: str = "tanh",
4345
prejoint_linear: bool = True,
4446
joint_mode: str = "add",
47+
joint_trainable: bool = True,
4548
kernel_regularizer=L2,
4649
bias_regularizer=L2,
47-
name: str = "conformer_transducer",
50+
name: str = "conformer",
4851
**kwargs):
4952
super(Conformer, self).__init__(
5053
encoder=ConformerEncoder(
@@ -60,7 +63,9 @@ def __init__(self,
6063
fc_factor=encoder_fc_factor,
6164
dropout=encoder_dropout,
6265
kernel_regularizer=kernel_regularizer,
63-
bias_regularizer=bias_regularizer
66+
bias_regularizer=bias_regularizer,
67+
trainable=encoder_trainable,
68+
name=f"{name}_encoder"
6469
),
6570
vocabulary_size=vocabulary_size,
6671
embed_dim=prediction_embed_dim,
@@ -71,13 +76,16 @@ def __init__(self,
7176
rnn_implementation=prediction_rnn_implementation,
7277
layer_norm=prediction_layer_norm,
7378
projection_units=prediction_projection_units,
79+
prediction_trainable=prediction_trainable,
7480
joint_dim=joint_dim,
7581
joint_activation=joint_activation,
7682
prejoint_linear=prejoint_linear,
7783
joint_mode=joint_mode,
84+
joint_trainable=joint_trainable,
7885
kernel_regularizer=kernel_regularizer,
7986
bias_regularizer=bias_regularizer,
80-
name=name, **kwargs
87+
name=name,
88+
**kwargs
8189
)
8290
self.dmodel = encoder_dmodel
8391
self.time_reduction_factor = self.encoder.conv_subsampling.time_reduction_factor

tensorflow_asr/models/keras/contextnet.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(self,
2525
vocabulary_size: int,
2626
encoder_blocks: List[dict],
2727
encoder_alpha: float = 0.5,
28+
encoder_trainable: bool = True,
2829
prediction_embed_dim: int = 512,
2930
prediction_embed_dropout: int = 0,
3031
prediction_num_rnns: int = 1,
@@ -33,10 +34,12 @@ def __init__(self,
3334
prediction_rnn_implementation: int = 2,
3435
prediction_layer_norm: bool = True,
3536
prediction_projection_units: int = 0,
37+
prediction_trainable: bool = True,
3638
joint_dim: int = 1024,
3739
joint_activation: str = "tanh",
3840
prejoint_linear: bool = True,
3941
joint_mode: str = "add",
42+
joint_trainable: bool = True,
4043
kernel_regularizer=L2,
4144
bias_regularizer=L2,
4245
name: str = "contextnet",
@@ -47,6 +50,7 @@ def __init__(self,
4750
alpha=encoder_alpha,
4851
kernel_regularizer=kernel_regularizer,
4952
bias_regularizer=bias_regularizer,
53+
trainable=encoder_trainable,
5054
name=f"{name}_encoder"
5155
),
5256
vocabulary_size=vocabulary_size,
@@ -58,13 +62,16 @@ def __init__(self,
5862
rnn_implementation=prediction_rnn_implementation,
5963
layer_norm=prediction_layer_norm,
6064
projection_units=prediction_projection_units,
65+
prediction_trainable=prediction_trainable,
6166
joint_dim=joint_dim,
6267
joint_activation=joint_activation,
6368
prejoint_linear=prejoint_linear,
6469
joint_mode=joint_mode,
70+
joint_trainable=joint_trainable,
6571
kernel_regularizer=kernel_regularizer,
6672
bias_regularizer=bias_regularizer,
67-
name=name, **kwargs
73+
name=name,
74+
**kwargs
6875
)
6976
self.dmodel = self.encoder.blocks[-1].dmodel
7077
self.time_reduction_factor = 1

tensorflow_asr/models/keras/streaming_transducer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,20 @@ def __init__(self,
2828
encoder_rnn_type: str = "lstm",
2929
encoder_rnn_units: int = 2048,
3030
encoder_layer_norm: bool = True,
31+
encoder_trainable: bool = True,
3132
prediction_embed_dim: int = 320,
3233
prediction_embed_dropout: float = 0,
3334
prediction_num_rnns: int = 2,
3435
prediction_rnn_units: int = 2048,
3536
prediction_rnn_type: str = "lstm",
3637
prediction_layer_norm: bool = True,
3738
prediction_projection_units: int = 640,
39+
prediction_trainable: bool = True,
3840
joint_dim: int = 640,
3941
joint_activation: str = "tanh",
4042
prejoint_linear: bool = True,
4143
joint_mode: str = "add",
44+
joint_trainable: bool = True,
4245
kernel_regularizer = None,
4346
bias_regularizer = None,
4447
name = "StreamingTransducer",
@@ -53,6 +56,7 @@ def __init__(self,
5356
layer_norm=encoder_layer_norm,
5457
kernel_regularizer=kernel_regularizer,
5558
bias_regularizer=bias_regularizer,
59+
trainable=encoder_trainable,
5660
name=f"{name}_encoder"
5761
),
5862
vocabulary_size=vocabulary_size,
@@ -63,10 +67,12 @@ def __init__(self,
6367
rnn_type=prediction_rnn_type,
6468
layer_norm=prediction_layer_norm,
6569
projection_units=prediction_projection_units,
70+
prediction_trainable=prediction_trainable,
6671
joint_dim=joint_dim,
6772
joint_activation=joint_activation,
6873
prejoint_linear=prejoint_linear,
6974
joint_mode=joint_mode,
75+
joint_trainable=joint_trainable,
7076
kernel_regularizer=kernel_regularizer,
7177
bias_regularizer=bias_regularizer,
7278
name=name, **kwargs

tensorflow_asr/models/streaming_transducer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,17 +183,20 @@ def __init__(self,
183183
encoder_rnn_type: str = "lstm",
184184
encoder_rnn_units: int = 2048,
185185
encoder_layer_norm: bool = True,
186+
encoder_trainable: bool = True,
186187
prediction_embed_dim: int = 320,
187188
prediction_embed_dropout: float = 0,
188189
prediction_num_rnns: int = 2,
189190
prediction_rnn_units: int = 2048,
190191
prediction_rnn_type: str = "lstm",
191192
prediction_layer_norm: bool = True,
192193
prediction_projection_units: int = 640,
194+
prediction_trainable: bool = True,
193195
joint_dim: int = 640,
194196
joint_activation: str = "tanh",
195197
prejoint_linear: bool = True,
196198
joint_mode: str = "add",
199+
joint_trainable: bool = True,
197200
kernel_regularizer = None,
198201
bias_regularizer = None,
199202
name = "StreamingTransducer",
@@ -208,6 +211,7 @@ def __init__(self,
208211
layer_norm=encoder_layer_norm,
209212
kernel_regularizer=kernel_regularizer,
210213
bias_regularizer=bias_regularizer,
214+
trainable=encoder_trainable,
211215
name=f"{name}_encoder"
212216
),
213217
vocabulary_size=vocabulary_size,
@@ -218,13 +222,16 @@ def __init__(self,
218222
rnn_type=prediction_rnn_type,
219223
layer_norm=prediction_layer_norm,
220224
projection_units=prediction_projection_units,
225+
prediction_trainable=prediction_trainable,
221226
joint_dim=joint_dim,
222227
joint_activation=joint_activation,
223228
prejoint_linear=prejoint_linear,
224229
joint_mode=joint_mode,
230+
joint_trainable=joint_trainable,
225231
kernel_regularizer=kernel_regularizer,
226232
bias_regularizer=bias_regularizer,
227-
name=name, **kwargs
233+
name=name,
234+
**kwargs
228235
)
229236
self.time_reduction_factor = self.encoder.time_reduction_factor
230237

tensorflow_asr/models/transducer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,10 +249,12 @@ def __init__(self,
249249
rnn_implementation: int = 2,
250250
layer_norm: bool = True,
251251
projection_units: int = 0,
252+
prediction_trainable: bool = True,
252253
joint_dim: int = 1024,
253254
joint_activation: str = "tanh",
254255
prejoint_linear: bool = True,
255256
joint_mode: str = "add",
257+
joint_trainable: bool = True,
256258
kernel_regularizer=None,
257259
bias_regularizer=None,
258260
name="transducer",
@@ -271,6 +273,7 @@ def __init__(self,
271273
projection_units=projection_units,
272274
kernel_regularizer=kernel_regularizer,
273275
bias_regularizer=bias_regularizer,
276+
trainable=prediction_trainable,
274277
name=f"{name}_prediction"
275278
)
276279
self.joint_net = TransducerJoint(
@@ -281,6 +284,7 @@ def __init__(self,
281284
joint_mode=joint_mode,
282285
kernel_regularizer=kernel_regularizer,
283286
bias_regularizer=bias_regularizer,
287+
trainable=joint_trainable,
284288
name=f"{name}_joint"
285289
)
286290
self.time_reduction_factor = 1

0 commit comments

Comments
 (0)