Commit 51d8c55

✍️ update example scripts
1 parent 4dbbb17 commit 51d8c55

15 files changed, +655 -490 lines changed

examples/conformer/tflite.py

Lines changed: 2 additions & 2 deletions
@@ -22,13 +22,13 @@
 from tensorflow_asr.configs.config import Config
 from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
 from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, CharFeaturizer
-from tensorflow_asr.models.conformer import Conformer
+from tensorflow_asr.models.transducer.conformer import Conformer

 DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

 tf.keras.backend.clear_session()

-parser = argparse.ArgumentParser(prog="Conformer Testing")
+parser = argparse.ArgumentParser(prog="Conformer TFLite")

 parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")


examples/contextnet/test.py

Lines changed: 50 additions & 24 deletions
@@ -13,17 +13,18 @@
 # limitations under the License.

 import os
+from tqdm import tqdm
 import argparse
-from tensorflow_asr.utils import setup_environment, setup_devices
+from tensorflow_asr.utils import env_util, file_util

-setup_environment()
+env_util.setup_environment()
 import tensorflow as tf

 DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

 tf.keras.backend.clear_session()

-parser = argparse.ArgumentParser(prog="ContextNet Testing")
+parser = argparse.ArgumentParser(prog="Contextnet Testing")

 parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")

@@ -33,60 +34,85 @@

 parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")

+parser.add_argument("--bs", type=int, default=None, help="Test batch size")
+
+parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
+
+parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
+
 parser.add_argument("--device", type=int, default=0, help="Device's id to run test on")

 parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu")

-parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
-
-parser.add_argument("--output_name", type=str, default="test", help="Result filename name prefix")
+parser.add_argument("--output", type=str, default="test.tsv", help="Result filepath")

 args = parser.parse_args()

+assert args.saved
+
 tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})

-setup_devices([args.device], cpu=args.cpu)
+env_util.setup_devices([args.device], cpu=args.cpu)

 from tensorflow_asr.configs.config import Config
 from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
 from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
-from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
-from tensorflow_asr.runners.base_runners import BaseTester
-from tensorflow_asr.models.contextnet import ContextNet
+from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer, CharFeaturizer
+from tensorflow_asr.models.transducer.contextnet import ContextNet
+from tensorflow_asr.utils import app_util

 config = Config(args.config)
 speech_featurizer = TFSpeechFeaturizer(config.speech_config)

-if args.subwords and os.path.exists(args.subwords):
-    print("Loading subwords ...")
-    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
+if args.sentence_piece:
+    print("Use SentencePiece ...")
+    text_featurizer = SentencePieceFeaturizer(config.decoder_config)
+elif args.subwords:
+    print("Use subwords ...")
+    text_featurizer = SubwordFeaturizer(config.decoder_config)
 else:
-    raise ValueError("subwords must be set")
+    print("Use characters ...")
+    text_featurizer = CharFeaturizer(config.decoder_config)

 tf.random.set_seed(0)
-assert args.saved

 if args.tfrecords:
     test_dataset = ASRTFRecordDataset(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.test_dataset_config)
     )
 else:
     test_dataset = ASRSliceDataset(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.test_dataset_config)
     )

 # build model
 contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
 contextnet._build(speech_featurizer.shape)
 contextnet.load_weights(args.saved)
-contextnet.summary(line_length=120)
+contextnet.summary(line_length=100)
 contextnet.add_featurizers(speech_featurizer, text_featurizer)

-contextnet_tester = BaseTester(
-    config=config.learning_config.running_config,
-    output_name=args.output_name
-)
-contextnet_tester.compile(contextnet)
-contextnet_tester.run(test_dataset)
+batch_size = args.bs or config.learning_config.running_config.batch_size
+test_data_loader = test_dataset.create(batch_size)
+
+with file_util.save_file(file_util.preprocess_paths(args.output)) as filepath:
+    overwrite = True
+    if tf.io.gfile.exists(filepath):
+        overwrite = input("Overwrite existing result file? (y/n): ").lower() == "y"
+    if overwrite:
+        results = contextnet.predict(test_data_loader, verbose=1)
+        print(f"Saving result to {args.output} ...")
+        with open(filepath, "w") as openfile:
+            openfile.write("PATH\tDURATION\tGROUNDTRUTH\tGREEDY\tBEAMSEARCH\n")
+            progbar = tqdm(total=test_dataset.total_steps, unit="batch")
+            for i, pred in enumerate(results):
+                groundtruth, greedy, beamsearch = [x.decode('utf-8') for x in pred]
+                path, duration, _ = test_dataset.entries[i]
+                openfile.write(f"{path}\t{duration}\t{groundtruth}\t{greedy}\t{beamsearch}\n")
+                progbar.update(1)
+            progbar.close()
+    app_util.evaluate_results(filepath)
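The test script now replaces the removed BaseTester runner with a plain Keras predict pass: it batches the dataset, writes one tab-separated row per utterance, and scores the file with app_util.evaluate_results. Below is a minimal sketch of consuming that TSV offline; the column names come from the header written above, while the word-error computation is an illustrative stand-in for the library's own evaluate_results, not its actual implementation:

import csv


def edit_distance(ref, hyp):
    # Word-level Levenshtein distance via a single rolling DP row.
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
    return dp[-1]


errors, words = 0, 0
with open("test.tsv", newline="") as f:  # default --output path from the diff
    for row in csv.DictReader(f, delimiter="\t"):
        ref = row["GROUNDTRUTH"].split()
        errors += edit_distance(ref, row["GREEDY"].split())
        words += len(ref)

print(f"greedy WER: {errors / max(words, 1):.2%}")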

examples/contextnet/tflite.py

Lines changed: 16 additions & 22 deletions
@@ -14,33 +14,29 @@

 import os
 import argparse
-from tensorflow_asr.utils import setup_environment
+from tensorflow_asr.utils import env_util, file_util

-setup_environment()
+env_util.setup_environment()
 import tensorflow as tf

 from tensorflow_asr.configs.config import Config
 from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
-from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
-from tensorflow_asr.models.contextnet import ContextNet
+from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, CharFeaturizer
+from tensorflow_asr.models.transducer.contextnet import ContextNet

 DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

 tf.keras.backend.clear_session()

-parser = argparse.ArgumentParser(prog="ContextNet Testing")
+parser = argparse.ArgumentParser(prog="ContextNet TFLite")

-parser.add_argument("--config", type=str, default=DEFAULT_YAML,
-                    help="The file path of model configuration file")
+parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")

-parser.add_argument("--saved", type=str, default=None,
-                    help="Path to saved model")
+parser.add_argument("--saved", type=str, default=None, help="Path to saved model")

-parser.add_argument("--subwords", type=str, default=None,
-                    help="Path to file that stores generated subwords")
+parser.add_argument("--subwords", type=str, default=None, help="Use subwords")

-parser.add_argument("output", type=str, default=None,
-                    help="TFLite file path to be exported")
+parser.add_argument("output", type=str, default=None, help="TFLite file path to be exported")

 args = parser.parse_args()

@@ -49,27 +45,25 @@
 config = Config(args.config)
 speech_featurizer = TFSpeechFeaturizer(config.speech_config)

-if args.subwords and os.path.exists(args.subwords):
-    print("Loading subwords ...")
-    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
+if args.subwords:
+    text_featurizer = SubwordFeaturizer(config.decoder_config)
 else:
-    raise ValueError("subwords must be set")
+    text_featurizer = CharFeaturizer(config.decoder_config)

 # build model
 contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
 contextnet._build(speech_featurizer.shape)
 contextnet.load_weights(args.saved)
-contextnet.summary(line_length=150)
+contextnet.summary(line_length=100)
 contextnet.add_featurizers(speech_featurizer, text_featurizer)

 concrete_func = contextnet.make_tflite_function().get_concrete_function()
 converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+converter.experimental_new_converter = True
 converter.optimizations = [tf.lite.Optimize.DEFAULT]
-converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
-                                       tf.lite.OpsSet.SELECT_TF_OPS]
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
 tflite_model = converter.convert()

-if not os.path.exists(os.path.dirname(args.output)):
-    os.makedirs(os.path.dirname(args.output))
+args.output = file_util.preprocess_paths(args.output)
 with open(args.output, "wb") as tflite_out:
     tflite_out.write(tflite_model)
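The conversion now opts into the new converter explicitly, keeps SELECT_TF_OPS enabled, and routes the output path through file_util.preprocess_paths in place of the manual os.makedirs handling. A quick smoke test for the exported model, assuming a placeholder path contextnet.tflite; because SELECT_TF_OPS is enabled, loading needs a TensorFlow build with the select ops available (the full tensorflow pip package qualifies):

import tensorflow as tf

# Load the exported file and report its input signature; "contextnet.tflite"
# is a placeholder for whatever path was passed as the `output` argument.
interpreter = tf.lite.Interpreter(model_path="contextnet.tflite")
interpreter.allocate_tensors()
for detail in interpreter.get_input_details():
    # Expected to show the streaming signature produced by
    # make_tflite_function (audio signal plus decoder state tensors).
    print(detail["name"], detail["shape"], detail["dtype"])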

examples/contextnet/train.py

Lines changed: 51 additions & 47 deletions
@@ -15,96 +15,99 @@
 import os
 import math
 import argparse
-from tensorflow_asr.utils import setup_environment, setup_strategy
+from tensorflow_asr.utils import env_util

-setup_environment()
+env_util.setup_environment()
 import tensorflow as tf

 DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

 tf.keras.backend.clear_session()

-parser = argparse.ArgumentParser(prog="ContextNet Training")
+parser = argparse.ArgumentParser(prog="Contextnet Training")

 parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")

-parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
-
 parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")

+parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
+
+parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
+
 parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")

 parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")

 parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance")

-parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata")
+parser.add_argument("--metadata", type=str, default=None, help="Path to file containing metadata")
+
+parser.add_argument("--static_length", default=False, action="store_true", help="Use static lengths")

 parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")

 parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")

-parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
-
-parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords")
-
 args = parser.parse_args()

 tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})

-strategy = setup_strategy(args.devices)
+strategy = env_util.setup_strategy(args.devices)

 from tensorflow_asr.configs.config import Config
-from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras, ASRSliceDatasetKeras
-from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
-from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
-from tensorflow_asr.models.keras.contextnet import ContextNet
+from tensorflow_asr.datasets import asr_dataset
+from tensorflow_asr.featurizers import speech_featurizers, text_featurizers
+from tensorflow_asr.models.transducer.contextnet import ContextNet
 from tensorflow_asr.optimizers.schedules import TransformerSchedule

 config = Config(args.config)
-speech_featurizer = TFSpeechFeaturizer(config.speech_config)
+speech_featurizer = speech_featurizers.TFSpeechFeaturizer(config.speech_config)

-if args.subwords and os.path.exists(args.subwords):
+if args.sentence_piece:
+    print("Loading SentencePiece model ...")
+    text_featurizer = text_featurizers.SentencePieceFeaturizer(config.decoder_config)
+elif args.subwords:
     print("Loading subwords ...")
-    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
+    text_featurizer = text_featurizers.SubwordFeaturizer(config.decoder_config)
 else:
-    print("Generating subwords ...")
-    text_featurizer = SubwordFeaturizer.build_from_corpus(
-        config.decoder_config,
-        corpus_files=args.subwords_corpus
-    )
-    text_featurizer.save_to_file(args.subwords)
+    print("Use characters ...")
+    text_featurizer = text_featurizers.CharFeaturizer(config.decoder_config)

 if args.tfrecords:
-    train_dataset = ASRTFRecordDatasetKeras(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+    train_dataset = asr_dataset.ASRTFRecordDataset(
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.train_dataset_config),
         indefinite=True
     )
-    eval_dataset = ASRTFRecordDatasetKeras(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+    eval_dataset = asr_dataset.ASRTFRecordDataset(
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.eval_dataset_config),
         indefinite=True
     )
-    # Update metadata calculated from both train and eval datasets
-    train_dataset.load_metadata(args.metadata_prefix)
-    eval_dataset.load_metadata(args.metadata_prefix)
-    # Use dynamic length
-    speech_featurizer.reset_length()
-    text_featurizer.reset_length()
 else:
-    train_dataset = ASRSliceDatasetKeras(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+    train_dataset = asr_dataset.ASRSliceDataset(
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.train_dataset_config),
         indefinite=True
     )
-    eval_dataset = ASRSliceDatasetKeras(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+    eval_dataset = asr_dataset.ASRSliceDataset(
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.eval_dataset_config),
         indefinite=True
     )

-global_batch_size = config.learning_config.running_config.batch_size
+train_dataset.load_metadata(args.metadata)
+eval_dataset.load_metadata(args.metadata)
+
+if not args.static_length:
+    speech_featurizer.reset_length()
+    text_featurizer.reset_length()
+
+global_batch_size = args.tbs or config.learning_config.running_config.batch_size
 global_batch_size *= strategy.num_replicas_in_sync

 train_data_loader = train_dataset.create(global_batch_size)

@@ -114,17 +117,15 @@
 # build model
 contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
 contextnet._build(speech_featurizer.shape)
-contextnet.summary(line_length=120)
+contextnet.summary(line_length=100)

 optimizer = tf.keras.optimizers.Adam(
     TransformerSchedule(
         d_model=contextnet.dmodel,
-        warmup_steps=config.learning_config.optimizer_config["warmup_steps"],
+        warmup_steps=config.learning_config.optimizer_config.pop("warmup_steps", 10000),
         max_lr=(0.05 / math.sqrt(contextnet.dmodel))
     ),
-    beta_1=config.learning_config.optimizer_config["beta1"],
-    beta_2=config.learning_config.optimizer_config["beta2"],
-    epsilon=config.learning_config.optimizer_config["epsilon"]
+    **config.learning_config.optimizer_config
 )

 contextnet.compile(

@@ -141,7 +142,10 @@
 ]

 contextnet.fit(
-    train_data_loader, epochs=config.learning_config.running_config.num_epochs,
-    validation_data=eval_data_loader, callbacks=callbacks,
-    steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
+    train_data_loader,
+    epochs=config.learning_config.running_config.num_epochs,
+    validation_data=eval_data_loader,
+    callbacks=callbacks,
+    steps_per_epoch=train_dataset.total_steps,
+    validation_steps=eval_dataset.total_steps
 )
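One behavioral note on the optimizer hunk: warmup_steps is popped out of optimizer_config before the remaining keys are splatted into tf.keras.optimizers.Adam, so the config file must now carry Keras's own argument names (beta_1, beta_2, epsilon) rather than the old beta1/beta2 keys. A toy sketch of the pattern, with placeholder values standing in for a real config:

import tensorflow as tf

# Placeholder values standing in for config.learning_config.optimizer_config.
optimizer_config = {"warmup_steps": 40000, "beta_1": 0.9, "beta_2": 0.98, "epsilon": 1e-9}

# pop() removes warmup_steps (with a default of 10000, as in the diff) so the
# remaining keys are all valid Adam keyword arguments; they must match the
# Keras names exactly, e.g. beta_1 rather than the old beta1 config key.
warmup_steps = optimizer_config.pop("warmup_steps", 10000)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, **optimizer_config)
print(optimizer.get_config()["beta_1"])  # 0.9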
