Skip to content

Commit c87fca6

Browse files
committed
fix: update saved model
1 parent fec1051 commit c87fca6

File tree

2 files changed

+17
-34
lines changed

2 files changed

+17
-34
lines changed

examples/conformer/saved_model.py

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,15 @@
2626

2727
parser = argparse.ArgumentParser(prog="Conformer Testing")
2828

29-
parser.add_argument(
30-
"--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file",
31-
)
29+
parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
3230

33-
parser.add_argument(
34-
"--h5", type=str, default=None, help="Path to saved h5 weights",
35-
)
31+
parser.add_argument("--h5", type=str, default=None, help="Path to saved h5 weights")
3632

37-
parser.add_argument(
38-
"--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model",
39-
)
33+
parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
4034

41-
parser.add_argument(
42-
"--subwords", default=False, action="store_true", help="Use subwords",
43-
)
35+
parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
4436

45-
parser.add_argument(
46-
"--output_dir", type=str, default=None, help="Output directory for saved model",
47-
)
37+
parser.add_argument("--output_dir", type=str, default=None, help="Output directory for saved model")
4838

4939
args = parser.parse_args()
5040

@@ -79,23 +69,14 @@
7969
conformer.add_featurizers(speech_featurizer, text_featurizer)
8070

8171

82-
class aModule(tf.Module):
83-
def __init__(self, model):
84-
super().__init__()
85-
self.model = model
72+
# TODO: Support saved model conversion
73+
# class ConformerModule(tf.Module):
74+
# def __init__(self, model: Conformer, name=None):
75+
# super().__init__(name=name)
76+
# self.model = model
77+
# self.pred = model.make_tflite_function()
8678

87-
@tf.function(
88-
input_signature=[
89-
{
90-
"inputs": tf.TensorSpec(shape=[None, None, 80, 1], dtype=tf.float32, name="inputs"),
91-
"inputs_length": tf.TensorSpec(shape=[None], dtype=tf.int32, name="inputs_length"),
92-
}
93-
]
94-
)
95-
def pred(self, input_batch):
96-
result = self.model.recognize(input_batch)
97-
return {"ASR": result}
9879

99-
100-
module = aModule(conformer)
101-
tf.saved_model.save(module, args.output_dir, signatures={"serving_default": module.pred})
80+
# model = ConformerModule(model=conformer)
81+
# tf.saved_model.save(model, args.output_dir)
82+
conformer.save(args.output_dir, include_optimizer=False, save_format="tf")

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ sentencepiece==0.1.96
66
tqdm==4.62.1
77
librosa==0.8.1
88
PyYAML==5.4.1
9-
Pillow==8.3.2
9+
Pillow==8.3.2
10+
black==21.7b0
11+
flake8==3.9.2

0 commit comments

Comments
 (0)