Skip to content

Commit 6105f5b

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Add tests to Cadence examples (#6233)
Summary: As titled. Adds deepfilternet, encodec and tacotron. Note that some of them cannot run yet and are under investigation, checking them in so that Cadence can see it on their side. Differential Revision: D64404190
1 parent 708c6b6 commit 6105f5b

File tree

3 files changed

+121
-0
lines changed

3 files changed

+121
-0
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Example script for exporting simple models to flatbuffer
8+
9+
import logging
10+
11+
from executorch.backends.cadence.aot.ops_registrations import * # noqa
12+
13+
import torch
14+
from df.enhance import init_df
15+
16+
from executorch.backends.cadence.aot.export_example import export_model
17+
18+
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
19+
logging.basicConfig(level=logging.INFO, format=FORMAT)
20+
21+
22+
def main() -> None:
23+
logging.info("DeepFilterNet requires pip install -q -U deepfilternet")
24+
model, _ , _= init_df()
25+
model.eval()
26+
27+
spec = torch.randn(1, 1, 1061, 481, 2)
28+
spec_feat = torch.randn(1, 1, 1061, 96, 2)
29+
erb_feat = torch.randn(1, 1, 1061, 32)
30+
31+
example_inputs = (
32+
spec,
33+
erb_feat,
34+
spec_feat,
35+
)
36+
37+
export_model(model, example_inputs)
38+
39+
40+
if __name__ == "__main__":
41+
main()

examples/cadence/models/encodec.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Example script for exporting simple models to flatbuffer
8+
9+
import logging
10+
11+
from executorch.backends.cadence.aot.ops_registrations import * # noqa
12+
13+
import torch
14+
from encodec import EncodecModel
15+
16+
from executorch.backends.cadence.aot.export_example import export_model
17+
18+
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
19+
logging.basicConfig(level=logging.INFO, format=FORMAT)
20+
21+
22+
def main() -> None:
23+
logging.info("Encodec requires pip install -U encodec")
24+
model = EncodecModel.encodec_model_24khz()
25+
model.set_target_bandwidth(6.0)
26+
model.eval()
27+
28+
wav = torch.randn(1, 1, 144000)
29+
30+
example_inputs = (wav,)
31+
32+
export_model(model, example_inputs)
33+
34+
35+
if __name__ == "__main__":
36+
main()
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Example script for exporting simple models to flatbuffer
8+
9+
import logging
10+
11+
from executorch.backends.cadence.aot.ops_registrations import * # noqa
12+
13+
import torchaudio
14+
15+
from executorch.backends.cadence.aot.export_example import export_model
16+
17+
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
18+
logging.basicConfig(level=logging.INFO, format=FORMAT)
19+
20+
21+
def main() -> None:
22+
text = "Hello world! T T S stands for Text to Speech!"
23+
24+
bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
25+
processor = bundle.get_text_processor()
26+
model = bundle.get_tacotron2()
27+
model.eval()
28+
29+
tokens, tokens_length = processor(text)
30+
specgram, specgram_lengths, _ = model.infer(tokens, tokens_length)
31+
32+
inputs = (
33+
tokens,
34+
tokens_length,
35+
specgram,
36+
specgram_lengths,
37+
)
38+
39+
# Full model
40+
export_model(model, inputs)
41+
42+
43+
if __name__ == "__main__":
44+
main()

0 commit comments

Comments
 (0)