
Commit 317fc37

Pull request pytorch#119: [EIEX-237] Add MicroSpeechLSTM to test models

Merge in AITEC/executorch from feature/nxf93343/microspeech-lstm-model to main-nxp

* commit '88090c585743a85f5c3e4abbeee2e435e9aea565':
  Add MicroSpeechLSTM to test models

2 parents (c0677c3 + 88090c5), merge commit 317fc37

7 files changed: +68 −1 lines


backends/nxp/tests/exported_program_vizualize.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def name_color(string): # pseudo-randomization function
     label = ""
     if "val" in node.meta:
         tensor = node.meta["val"]
-        if isinstance(tensor, tuple):
+        if isinstance(tensor, tuple) or isinstance(tensor, list):
             tensor = tensor[0]  # Fake tensor
         label = f" ({list(tensor.shape)} | {tensor.dtype})"
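For context, a minimal, self-contained sketch (not part of the commit) of the labeling behaviour this hunk enables; format_node_label is a hypothetical helper name, and the list case mirrors multi-output ops (such as an LSTM) whose node.meta["val"] holds several fake tensors:

import torch

def format_node_label(meta_val):
    # Multi-output ops may store a tuple or list of tensors in meta["val"];
    # take the first one as the representative tensor for the label.
    if isinstance(meta_val, (tuple, list)):
        meta_val = meta_val[0]
    return f" ({list(meta_val.shape)} | {meta_val.dtype})"

print(format_node_label(torch.zeros(1, 128)))                       # single tensor
print(format_node_label([torch.zeros(1, 128), torch.zeros(1, 3)]))  # list of tensors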

examples/nxp/aot_neutron_compile.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 from executorch.examples.models.model_factory import EagerModelFactory
 from executorch.examples.nxp.cifar_net.cifar_net import CifarNet
 from executorch.examples.nxp.cifar_net.cifar_net import test_cifarnet_model
+from executorch.examples.nxp.models.microspeech_lstm.microspeech_lstm import MicroSpeechLSTM
 from executorch.examples.nxp.models.mlperf_tiny import (AnomalyDetection, KeywordSpotting, ImageClassification,
                                                         VisualWakeWords)
 from executorch.exir import ExecutorchBackendConfig

@@ -90,6 +91,7 @@ def get_model_and_inputs_from_name(model_name: str):
         "keyword_spotting": KeywordSpotting,
         "image_classification": ImageClassification,
         "anomaly_detection": AnomalyDetection,
+        "microspeech_lstm": MicroSpeechLSTM,
     }
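A hedged sketch (assumed, not shown in this diff) of how the new registry entry is presumably consumed by get_model_and_inputs_from_name(): the model name maps to an EagerModelBase subclass, which is instantiated and asked for the eager module and example inputs. The dictionary name and the instantiation details below are assumptions.

from executorch.examples.nxp.models.microspeech_lstm.microspeech_lstm import MicroSpeechLSTM

models = {
    "microspeech_lstm": MicroSpeechLSTM,
    # ... the other example models are registered the same way
}

wrapper = models["microspeech_lstm"]()          # loads weights and calibration data from the model directory
model = wrapper.get_eager_model()               # torch.nn.Module with the trained state dict applied
example_inputs = wrapper.get_example_inputs()   # tuple of tensors suitable for tracing/export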

examples/nxp/models/microspeech_lstm/__init__.py

Whitespace-only changes.
Binary file (768 KB) not shown.
examples/nxp/models/microspeech_lstm/microspeech_lstm.py (path inferred from the import added in aot_neutron_compile.py above)

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import os
from typing import Iterator

import torch
from torch import nn
from torch.utils.data import DataLoader

from executorch.examples.models import model_base


class MicroSpeechLSTM(model_base.EagerModelBase):

    def __init__(self):
        self._weights_file = os.path.join(os.path.dirname(__file__), 'microspeech_lstm_best.pth')
        calibration_dataset_path = os.path.join(os.path.dirname(__file__), 'calibration_data.pt')
        self._calibration_dataset = torch.load(calibration_dataset_path, map_location=torch.device('cpu'))

    @staticmethod
    def _collate_fn(data: list[tuple]):
        data, labels = zip(*data)
        return torch.stack(list(data)), torch.tensor(list(labels))

    def get_eager_model(self) -> torch.nn.Module:
        lstm_module = MicroSpeechLSTMModule()
        lstm_module.load_state_dict(torch.load(self._weights_file, weights_only=True, map_location=torch.device('cpu')))

        return lstm_module

    def get_example_inputs(self) -> tuple[torch.Tensor]:
        sample = self._calibration_dataset[0]

        return (sample[0].unsqueeze(0),)  # use only sample data, not label

    def get_calibration_inputs(self, batch_size: int = 1) -> Iterator[tuple[torch.Tensor]]:
        """
        Get LSTM's calibration input. Batch size is ignored and set to 1 because the hidden state
        has to be initialized to the size of the batch.

        :param batch_size: Ignored.
        :return: Iterator with input calibration dataset samples.
        """
        data_loader = DataLoader(self._calibration_dataset, batch_size=1)
        return itertools.starmap(lambda data, label: (data,), iter(data_loader))


class MicroSpeechLSTMModule(nn.Module):

    def __init__(self, input_size=80, output_size=3, features_size=128):
        super(MicroSpeechLSTMModule, self).__init__()
        self.lstm = nn.LSTM(input_size, features_size, batch_first=True)
        self.linear = nn.Linear(features_size, output_size)

    def forward(self, x):
        _, (x, _) = self.lstm(x)
        x = x.squeeze(0)
        x = self.linear(x)
        x = nn.functional.softmax(x, dim=1)
        return x
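A quick, self-contained shape check for the module defined above (not part of the commit). The 80-feature input size and the 3-class output come from the constructor defaults; the 49-frame sequence length is an arbitrary assumption for illustration.

import torch
from executorch.examples.nxp.models.microspeech_lstm.microspeech_lstm import MicroSpeechLSTMModule

module = MicroSpeechLSTMModule()   # input_size=80, features_size=128, output_size=3
x = torch.randn(1, 49, 80)         # (batch, time, features) because the LSTM uses batch_first=True
probs = module(x)                  # softmax over the 3 output classes
print(probs.shape)                 # torch.Size([1, 3])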
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+3ffab80556d0e543876479d4de245d76 microspeech_lstm_best.pth
Binary file (424 KB) not shown.
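A hedged sketch (not part of the commit) of verifying the weights file against the checksum committed above; the relative path is an assumption about the working directory.

import hashlib

with open("microspeech_lstm_best.pth", "rb") as f:
    digest = hashlib.md5(f.read()).hexdigest()

assert digest == "3ffab80556d0e543876479d4de245d76", "weights checksum mismatch"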

0 commit comments
