|
26 | 26 |
|
27 | 27 | parser = argparse.ArgumentParser(prog="Conformer Testing") |
28 | 28 |
|
29 | | -parser.add_argument( |
30 | | - "--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file", |
31 | | -) |
| 29 | +parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file") |
32 | 30 |
|
33 | | -parser.add_argument( |
34 | | - "--h5", type=str, default=None, help="Path to saved h5 weights", |
35 | | -) |
| 31 | +parser.add_argument("--h5", type=str, default=None, help="Path to saved h5 weights") |
36 | 32 |
|
37 | | -parser.add_argument( |
38 | | - "--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model", |
39 | | -) |
| 33 | +parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") |
40 | 34 |
|
41 | | -parser.add_argument( |
42 | | - "--subwords", default=False, action="store_true", help="Use subwords", |
43 | | -) |
| 35 | +parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords") |
44 | 36 |
|
45 | | -parser.add_argument( |
46 | | - "--output_dir", type=str, default=None, help="Output directory for saved model", |
47 | | -) |
| 37 | +parser.add_argument("--output_dir", type=str, default=None, help="Output directory for saved model") |
48 | 38 |
|
49 | 39 | args = parser.parse_args() |
50 | 40 |
|
|
79 | 69 | conformer.add_featurizers(speech_featurizer, text_featurizer) |
80 | 70 |
|
81 | 71 |
|
82 | | -class aModule(tf.Module): |
83 | | - def __init__(self, model): |
84 | | - super().__init__() |
85 | | - self.model = model |
| 72 | +# TODO: Support saved model conversion |
| 73 | +# class ConformerModule(tf.Module): |
| 74 | +# def __init__(self, model: Conformer, name=None): |
| 75 | +# super().__init__(name=name) |
| 76 | +# self.model = model |
| 77 | +# self.pred = model.make_tflite_function() |
86 | 78 |
|
87 | | - @tf.function( |
88 | | - input_signature=[ |
89 | | - { |
90 | | - "inputs": tf.TensorSpec(shape=[None, None, 80, 1], dtype=tf.float32, name="inputs"), |
91 | | - "inputs_length": tf.TensorSpec(shape=[None], dtype=tf.int32, name="inputs_length"), |
92 | | - } |
93 | | - ] |
94 | | - ) |
95 | | - def pred(self, input_batch): |
96 | | - result = self.model.recognize(input_batch) |
97 | | - return {"ASR": result} |
98 | 79 |
|
99 | | - |
100 | | -module = aModule(conformer) |
101 | | -tf.saved_model.save(module, args.output_dir, signatures={"serving_default": module.pred}) |
| 80 | +# model = ConformerModule(model=conformer) |
| 81 | +# tf.saved_model.save(model, args.output_dir) |
| 82 | +conformer.save(args.output_dir, include_optimizer=False, save_format="tf") |
0 commit comments