Skip to content

Commit 55ccc62

Browse files
committed
Add support for MobilenetV2 model
1 parent 3c87af3 commit 55ccc62

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

examples/nxp/aot_neutron_compile.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from executorch.examples.nxp.models.microspeech_lstm.microspeech_lstm import MicroSpeechLSTM
2727
from executorch.examples.nxp.models.mlperf_tiny import (AnomalyDetection, KeywordSpotting, ImageClassification,
2828
VisualWakeWords)
29+
from executorch.examples.nxp.models.mobilenet_v2 import MobilenetV2
2930
from executorch.exir import ExecutorchBackendConfig
3031
from executorch.extension.export_util import export_to_edge, save_pte_program
3132

@@ -92,6 +93,7 @@ def get_model_and_inputs_from_name(model_name: str):
9293
"image_classification": ImageClassification,
9394
"anomaly_detection": AnomalyDetection,
9495
"microspeech_lstm": MicroSpeechLSTM,
96+
"mobilenetv2": MobilenetV2,
9597
}
9698

9799

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import itertools
7+
from typing import Iterator
8+
9+
import torch
10+
import torchvision
11+
from torch.utils.data import DataLoader
12+
from torchvision import transforms
13+
14+
from executorch.examples.models.mobilenet_v2 import MV2Model
15+
16+
17+
class MobilenetV2(MV2Model):
18+
19+
def get_calibration_inputs(self, batch_size: int = 1) -> Iterator[tuple[torch.Tensor]]:
20+
"""
21+
Returns an iterator for the Imagenette validation dataset, downloading it if necessary.
22+
23+
Args:
24+
batch_size (int): The batch size for the iterator.
25+
26+
Returns:
27+
iterator: An iterator that yields batches of images from the Imagnetette validation dataset.
28+
"""
29+
dataloader = self.get_dataset(batch_size)
30+
31+
# Return the iterator
32+
dataloader_iterable = itertools.starmap(lambda data, label: (data,), iter(dataloader))
33+
34+
# We want approximately 500 samples
35+
batch_count = 500 // batch_size
36+
return itertools.islice(dataloader_iterable, batch_count)
37+
38+
def get_dataset(self, batch_size):
39+
# Define data transformations
40+
data_transforms = transforms.Compose([
41+
transforms.Resize((224, 224)),
42+
transforms.ToTensor(),
43+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet stats
44+
])
45+
46+
dataset = torchvision.datasets.Imagenette(root='./data', split='val', transform=data_transforms, download=True)
47+
dataloader = torch.utils.data.DataLoader(
48+
dataset,
49+
batch_size=batch_size,
50+
shuffle=False,
51+
num_workers=1,
52+
)
53+
return dataloader
54+
55+
56+
def gather_samples_per_class_from_dataloader(dataloader, num_samples_per_class=10) -> list[tuple]:
57+
"""
58+
Gathers a specified number of samples for each class from a DataLoader.
59+
60+
Args:
61+
dataloader (DataLoader): The PyTorch DataLoader object.
62+
num_samples_per_class (int): The number of samples to gather for each class. Defaults to 10.
63+
64+
Returns:
65+
samples: A list of (sample, label) tuples.
66+
"""
67+
68+
if not isinstance(dataloader, DataLoader):
69+
raise TypeError("dataloader must be a torch.utils.data.DataLoader object")
70+
if not isinstance(num_samples_per_class, int) or num_samples_per_class <= 0:
71+
raise ValueError("num_samples_per_class must be a positive integer")
72+
73+
labels = sorted(list(set([label for _, label in dataloader.dataset]))) # Get unique labels from the dataset
74+
samples_per_label = {label: [] for label in labels} # Initialize dictionary
75+
76+
for sample, label in dataloader:
77+
label = label.item()
78+
if len(samples_per_label[label]) < num_samples_per_class:
79+
samples_per_label[label].append((sample, label))
80+
81+
samples = []
82+
83+
for label in labels:
84+
samples.extend(samples_per_label[label])
85+
86+
return samples
87+
88+
89+
def generate_input_samples_file():
90+
model = MobilenetV2()
91+
dataloader = model.get_dataset(batch_size=1)
92+
samples = gather_samples_per_class_from_dataloader(dataloader, num_samples_per_class=2)
93+
94+
torch.save(samples, "calibration_data.pt")
95+
96+
97+
if __name__ == '__main__':
98+
generate_input_samples_file()

0 commit comments

Comments
 (0)