Skip to content

Commit 2c90f67

Browse files
committed
✍️ update examples train scripts
1 parent 6114fb6 commit 2c90f67

File tree

21 files changed

+158
-50
lines changed

21 files changed

+158
-50
lines changed

examples/conformer/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
# build model
8484
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
8585
conformer.make(speech_featurizer.shape)
86-
conformer.load_weights(args.saved)
86+
conformer.load_weights(args.saved, by_name=True)
8787
conformer.summary(line_length=100)
8888
conformer.add_featurizers(speech_featurizer, text_featurizer)
8989

examples/conformer/tflite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
# build model
5454
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
5555
conformer.make(speech_featurizer.shape)
56-
conformer.load_weights(args.saved)
56+
conformer.load_weights(args.saved, by_name=True)
5757
conformer.summary(line_length=100)
5858
conformer.add_featurizers(speech_featurizer, text_featurizer)
5959

examples/conformer/train.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646

4747
parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
4848

49+
parser.add_argument("--pretrained", type=str, default=None, help="Path to pretrained model")
50+
4951
args = parser.parse_args()
5052

5153
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
@@ -119,8 +121,9 @@
119121
prediction_shape=text_featurizer.prepand_shape,
120122
batch_size=global_batch_size
121123
)
124+
if args.pretrained:
125+
conformer.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
122126
conformer.summary(line_length=100)
123-
124127
optimizer = tf.keras.optimizers.Adam(
125128
TransformerSchedule(
126129
d_model=conformer.dmodel,
@@ -129,7 +132,6 @@
129132
),
130133
**config.learning_config.optimizer_config
131134
)
132-
133135
conformer.compile(
134136
optimizer=optimizer,
135137
experimental_steps_per_execution=args.spx,

examples/contextnet/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
# build model
8484
contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
8585
contextnet.make(speech_featurizer.shape)
86-
contextnet.load_weights(args.saved)
86+
contextnet.load_weights(args.saved, by_name=True)
8787
contextnet.summary(line_length=100)
8888
contextnet.add_featurizers(speech_featurizer, text_featurizer)
8989

examples/contextnet/tflite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
# build model
5454
contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
5555
contextnet.make(speech_featurizer.shape)
56-
contextnet.load_weights(args.saved)
56+
contextnet.load_weights(args.saved, by_name=True)
5757
contextnet.summary(line_length=100)
5858
contextnet.add_featurizers(speech_featurizer, text_featurizer)
5959

examples/contextnet/train.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646

4747
parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
4848

49+
parser.add_argument("--pretrained", type=str, default=None, help="Path to pretrained model")
50+
4951
args = parser.parse_args()
5052

5153
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
@@ -119,8 +121,9 @@
119121
prediction_shape=text_featurizer.prepand_shape,
120122
batch_size=global_batch_size
121123
)
124+
if args.pretrained:
125+
contextnet.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
122126
contextnet.summary(line_length=100)
123-
124127
optimizer = tf.keras.optimizers.Adam(
125128
TransformerSchedule(
126129
d_model=contextnet.dmodel,
@@ -129,7 +132,6 @@
129132
),
130133
**config.learning_config.optimizer_config
131134
)
132-
133135
contextnet.compile(
134136
optimizer=optimizer,
135137
experimental_steps_per_execution=args.spx,

examples/deepspeech2/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
# build model
8484
deepspeech2 = DeepSpeech2(**config.model_config, vocabulary_size=text_featurizer.num_classes)
8585
deepspeech2.make(speech_featurizer.shape)
86-
deepspeech2.load_weights(args.saved)
86+
deepspeech2.load_weights(args.saved, by_name=True)
8787
deepspeech2.summary(line_length=100)
8888
deepspeech2.add_featurizers(speech_featurizer, text_featurizer)
8989

examples/deepspeech2/tflite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
# build model
5454
deepspeech2 = DeepSpeech2(**config.model_config, vocabulary_size=text_featurizer.num_classes)
5555
deepspeech2.make(speech_featurizer.shape)
56-
deepspeech2.load_weights(args.saved)
56+
deepspeech2.load_weights(args.saved, by_name=True)
5757
deepspeech2.summary(line_length=100)
5858
deepspeech2.add_featurizers(speech_featurizer, text_featurizer)
5959

examples/deepspeech2/train.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545

4646
parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
4747

48+
parser.add_argument("--pretrained", type=str, default=None, help="Path to pretrained model")
49+
4850
args = parser.parse_args()
4951

5052
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
@@ -113,6 +115,8 @@
113115
# build model
114116
deepspeech2 = DeepSpeech2(**config.model_config, vocabulary_size=text_featurizer.num_classes)
115117
deepspeech2.make(speech_featurizer.shape, batch_size=global_batch_size)
118+
if args.pretrained:
119+
deepspeech2.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
116120
deepspeech2.summary(line_length=100)
117121
deepspeech2.compile(
118122
optimizer=config.learning_config.optimizer_config,

examples/demonstration/conformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
# build model
6565
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
6666
conformer.make(speech_featurizer.shape)
67-
conformer.load_weights(args.saved)
67+
conformer.load_weights(args.saved, by_name=True, skip_mismatch=True)
6868
conformer.summary(line_length=120)
6969
conformer.add_featurizers(speech_featurizer, text_featurizer)
7070

0 commit comments

Comments
 (0)