|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import os |
16 | | -from tqdm import tqdm |
17 | | -import argparse |
18 | | -from tensorflow_asr.utils import env_util, file_util |
| 16 | +import fire |
| 17 | +from tensorflow_asr.utils import env_util |
19 | 18 |
|
20 | 19 | logger = env_util.setup_environment() |
21 | 20 | import tensorflow as tf |
22 | 21 |
|
23 | | -DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") |
24 | | - |
25 | | -tf.keras.backend.clear_session() |
26 | | - |
27 | | -parser = argparse.ArgumentParser(prog="Contextnet Testing") |
28 | | - |
29 | | -parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file") |
30 | | - |
31 | | -parser.add_argument("--saved", type=str, default=None, help="Path to saved model") |
32 | | - |
33 | | -parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") |
34 | | - |
35 | | -parser.add_argument("--bs", type=int, default=None, help="Test batch size") |
36 | | - |
37 | | -parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") |
38 | | - |
39 | | -parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords") |
40 | | - |
41 | | -parser.add_argument("--device", type=int, default=0, help="Device's id to run test on") |
42 | | - |
43 | | -parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu") |
44 | | - |
45 | | -parser.add_argument("--output", type=str, default="test.tsv", help="Result filepath") |
46 | | - |
47 | | -args = parser.parse_args() |
48 | | - |
49 | | -assert args.saved |
50 | | - |
51 | | -tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp}) |
52 | | - |
53 | | -env_util.setup_devices([args.device], cpu=args.cpu) |
54 | | - |
55 | 22 | from tensorflow_asr.configs.config import Config |
56 | | -from tensorflow_asr.datasets.asr_dataset import ASRSliceDataset |
57 | | -from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer |
58 | | -from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer, CharFeaturizer |
59 | 23 | from tensorflow_asr.models.transducer.contextnet import ContextNet |
60 | | -from tensorflow_asr.utils import app_util |
| 24 | +from tensorflow_asr.helpers import exec_helpers, dataset_helpers, featurizer_helpers |
61 | 25 |
|
62 | | -config = Config(args.config) |
63 | | -speech_featurizer = TFSpeechFeaturizer(config.speech_config) |
64 | | - |
65 | | -if args.sentence_piece: |
66 | | - logger.info("Use SentencePiece ...") |
67 | | - text_featurizer = SentencePieceFeaturizer(config.decoder_config) |
68 | | -elif args.subwords: |
69 | | - logger.info("Use subwords ...") |
70 | | - text_featurizer = SubwordFeaturizer(config.decoder_config) |
71 | | -else: |
72 | | - logger.info("Use characters ...") |
73 | | - text_featurizer = CharFeaturizer(config.decoder_config) |
74 | | - |
75 | | -tf.random.set_seed(0) |
76 | | - |
77 | | -test_dataset = ASRSliceDataset( |
78 | | - speech_featurizer=speech_featurizer, |
79 | | - text_featurizer=text_featurizer, |
80 | | - **vars(config.learning_config.test_dataset_config) |
81 | | -) |
82 | | - |
83 | | -# build model |
84 | | -contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes) |
85 | | -contextnet.make(speech_featurizer.shape) |
86 | | -contextnet.load_weights(args.saved, by_name=True) |
87 | | -contextnet.summary(line_length=100) |
88 | | -contextnet.add_featurizers(speech_featurizer, text_featurizer) |
| 26 | +DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") |
89 | 27 |
|
90 | | -batch_size = args.bs or config.learning_config.running_config.batch_size |
91 | | -test_data_loader = test_dataset.create(batch_size) |
92 | 28 |
|
93 | | -with file_util.save_file(file_util.preprocess_paths(args.output)) as filepath: |
94 | | - overwrite = True |
95 | | - if tf.io.gfile.exists(filepath): |
96 | | - overwrite = input(f"Overwrite existing result file {filepath} ? (y/n): ").lower() == "y" |
97 | | - if overwrite: |
98 | | - results = contextnet.predict(test_data_loader, verbose=1) |
99 | | - logger.info(f"Saving result to {args.output} ...") |
100 | | - with open(filepath, "w") as openfile: |
101 | | - openfile.write("PATH\tDURATION\tGROUNDTRUTH\tGREEDY\tBEAMSEARCH\n") |
102 | | - progbar = tqdm(total=test_dataset.total_steps, unit="batch") |
103 | | - for i, pred in enumerate(results): |
104 | | - groundtruth, greedy, beamsearch = [x.decode('utf-8') for x in pred] |
105 | | - path, duration, _ = test_dataset.entries[i] |
106 | | - openfile.write(f"{path}\t{duration}\t{groundtruth}\t{greedy}\t{beamsearch}\n") |
107 | | - progbar.update(1) |
108 | | - progbar.close() |
109 | | - app_util.evaluate_results(filepath) |
| 29 | +def main( |
| 30 | + config: str = DEFAULT_YAML, |
| 31 | + saved: str = None, |
| 32 | + mxp: bool = False, |
| 33 | + bs: int = None, |
| 34 | + sentence_piece: bool = False, |
| 35 | + subwords: bool = False, |
| 36 | + device: int = 0, |
| 37 | + cpu: bool = False, |
| 38 | + output: str = "test.tsv", |
| 39 | +): |
| 40 | + assert saved and output |
| 41 | + tf.random.set_seed(0) |
| 42 | + tf.keras.backend.clear_session() |
| 43 | + tf.config.optimizer.set_experimental_options({"auto_mixed_precision": mxp}) |
| 44 | + env_util.setup_devices([device], cpu=cpu) |
| 45 | + |
| 46 | + config = Config(config) |
| 47 | + |
| 48 | + speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers( |
| 49 | + config=config, |
| 50 | + subwords=subwords, |
| 51 | + sentence_piece=sentence_piece, |
| 52 | + ) |
| 53 | + |
| 54 | + contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes) |
| 55 | + contextnet.make(speech_featurizer.shape) |
| 56 | + contextnet.load_weights(saved, by_name=True) |
| 57 | + contextnet.summary(line_length=100) |
| 58 | + contextnet.add_featurizers(speech_featurizer, text_featurizer) |
| 59 | + |
| 60 | + test_dataset = dataset_helpers.prepare_testing_datasets( |
| 61 | + config=config, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer |
| 62 | + ) |
| 63 | + batch_size = bs or config.learning_config.running_config.batch_size |
| 64 | + test_data_loader = test_dataset.create(batch_size) |
| 65 | + |
| 66 | + exec_helpers.run_testing(model=contextnet, test_dataset=test_dataset, test_data_loader=test_data_loader, output=output) |
| 67 | + |
| 68 | + |
| 69 | +if __name__ == "__main__": |
| 70 | + fire.Fire(main) |
0 commit comments