14 | 14 |
15 | 15 | import os |
16 | 16 | import argparse |
17 | | -from tensorflow_asr.utils import setup_environment, setup_devices |
| 17 | +from tensorflow_asr.utils import env_util, file_util |
18 | 18 |
19 | | -setup_environment() |
| 19 | +env_util.setup_environment() |
20 | 20 | import tensorflow as tf |
21 | 21 |
22 | 22 | DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") |
33 | 33 |
34 | 34 | parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") |
35 | 35 |
| 36 | +parser.add_argument("--bs", type=int, default=None, help="Test batch size") |
| 37 | + |
36 | 38 | parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") |
37 | 39 |
| 40 | +parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords") |
| 41 | + |
38 | 42 | parser.add_argument("--device", type=int, default=0, help="Device's id to run test on") |
39 | 43 |
40 | 44 | parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu") |
41 | 45 |
42 | | -parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords") |
43 | | - |
44 | | -parser.add_argument("--output_name", type=str, default="test", help="Result filename name prefix") |
| 46 | +parser.add_argument("--output", type=str, default="test.tsv", help="Result filepath") |
45 | 47 |
46 | 48 | args = parser.parse_args() |
47 | 49 |
| 50 | +assert args.saved |
| 51 | + |
48 | 52 | tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp}) |
49 | 53 |
50 | | -setup_devices([args.device], cpu=args.cpu) |
| 54 | +env_util.setup_devices([args.device], cpu=args.cpu) |
51 | 55 |
52 | 56 | from tensorflow_asr.configs.config import Config |
53 | 57 | from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset |
54 | 58 | from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer |
55 | | -from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer |
56 | | -from tensorflow_asr.runners.base_runners import BaseTester |
57 | | -from tensorflow_asr.models.conformer import Conformer |
| 59 | +from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer, CharFeaturizer |
| 60 | +from tensorflow_asr.models.transducer.conformer import Conformer |
58 | 61 |
59 | 62 | config = Config(args.config) |
60 | 63 | speech_featurizer = TFSpeechFeaturizer(config.speech_config) |
61 | 64 |
62 | 65 | if args.sentence_piece: |
63 | | - print("Loading SentencePiece model ...") |
64 | | - text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords) |
65 | | -elif args.subwords and os.path.exists(args.subwords): |
66 | | - print("Loading subwords ...") |
67 | | - text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) |
| 66 | + print("Use SentencePiece ...") |
| 67 | + text_featurizer = SentencePieceFeaturizer(config.decoder_config) |
| 68 | +elif args.subwords: |
| 69 | + print("Use subwords ...") |
| 70 | + text_featurizer = SubwordFeaturizer(config.decoder_config) |
68 | 71 | else: |
69 | | - raise ValueError("subwords must be set") |
| 72 | + print("Use characters ...") |
| 73 | + text_featurizer = CharFeaturizer(config.decoder_config) |
70 | 74 |
71 | 75 | tf.random.set_seed(0) |
72 | | -assert args.saved |
73 | 76 |
74 | 77 | if args.tfrecords: |
75 | 78 | test_dataset = ASRTFRecordDataset( |
76 | | - speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, |
| 79 | + speech_featurizer=speech_featurizer, |
| 80 | + text_featurizer=text_featurizer, |
77 | 81 | **vars(config.learning_config.test_dataset_config) |
78 | 82 | ) |
79 | 83 | else: |
80 | 84 | test_dataset = ASRSliceDataset( |
81 | | - speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, |
| 85 | + speech_featurizer=speech_featurizer, |
| 86 | + text_featurizer=text_featurizer, |
82 | 87 | **vars(config.learning_config.test_dataset_config) |
83 | 88 | ) |
84 | 89 |
85 | 90 | # build model |
86 | 91 | conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) |
87 | 92 | conformer._build(speech_featurizer.shape) |
88 | 93 | conformer.load_weights(args.saved) |
89 | | -conformer.summary(line_length=120) |
| 94 | +conformer.summary(line_length=100) |
90 | 95 | conformer.add_featurizers(speech_featurizer, text_featurizer) |
91 | 96 |
92 | | -conformer_tester = BaseTester( |
93 | | - config=config.learning_config.running_config, |
94 | | - output_name=args.output_name |
95 | | -) |
96 | | -conformer_tester.compile(conformer) |
97 | | -conformer_tester.run(test_dataset) |
| 97 | +batch_size = args.bs or config.learning_config.running_config.batch_size |
| 98 | +test_data_loader = test_dataset.create(batch_size) |
| 99 | + |
| 100 | +results = conformer.predict(test_data_loader) |
| 101 | + |
| 102 | +with file_util.save_file(file_util.preprocess_paths(args.output)) as filepath: |
| 103 | + print(f"Saving result to {args.output} ...") |
| 104 | + with open(filepath, "w") as openfile: |
| 105 | + openfile.write("PATH\tDURATION\tGROUNDTRUTH\tGREEDY\tBEAMSEARCH\n") |
| 106 | + for i, entry in test_dataset.entries: |
| 107 | + groundtruth, greedy, beamsearch = results[i] |
| 108 | + path, duration, _ = entry |
| 109 | + openfile.write(f"{path}\t{duration}\t{groundtruth}\t{greedy}\t{beamsearch}\n") |
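For reference, the script now writes its results as a plain TSV with a self-describing header row, so the file can be consumed downstream with nothing beyond the standard library. A minimal sketch of reading it back, assuming the default `--output` path `test.tsv`; `QUOTE_NONE` matches the raw, unquoted writes above:

```python
import csv

# Read the per-utterance results written by the test script.
# Columns: PATH, DURATION, GROUNDTRUTH, GREEDY, BEAMSEARCH.
with open("test.tsv", newline="") as tsvfile:
    reader = csv.DictReader(tsvfile, delimiter="\t", quoting=csv.QUOTE_NONE)
    for row in reader:
        print(f"{row['PATH']}: ref={row['GROUNDTRUTH']!r}, "
              f"greedy={row['GREEDY']!r}, beam={row['BEAMSEARCH']!r}")
```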