Skip to content

Commit b587ea7

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Add mobilenet_v2 and resnet50 to the examples
Summary: As titled. Both from TorchVision. Differential Revision: D66132529
1 parent a0ac820 commit b587ea7

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Example script for exporting simple models to flatbuffer
8+
9+
import logging
10+
11+
import torch
12+
13+
from executorch.backends.cadence.aot.ops_registrations import * # noqa
14+
15+
16+
from executorch.backends.cadence.aot.export_example import export_model
17+
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
18+
19+
20+
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
21+
logging.basicConfig(level=logging.INFO, format=FORMAT)
22+
23+
24+
if __name__ == "__main__":
25+
26+
model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
27+
model.eval()
28+
example_inputs = (torch.randn(1, 3, 64, 64),)
29+
30+
export_model(model, example_inputs)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Example script for exporting simple models to flatbuffer
8+
9+
import logging
10+
11+
import torch
12+
13+
from executorch.backends.cadence.aot.ops_registrations import * # noqa
14+
15+
16+
from executorch.backends.cadence.aot.export_example import export_model
17+
from torchvision.models import resnet50, ResNet50_Weights
18+
19+
20+
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
21+
logging.basicConfig(level=logging.INFO, format=FORMAT)
22+
23+
24+
if __name__ == "__main__":
25+
26+
model = resnet50(weights=ResNet50_Weights.DEFAULT)
27+
model.eval()
28+
example_inputs = (torch.randn(1, 3, 64, 64),)
29+
30+
export_model(model, example_inputs)

0 commit comments

Comments
 (0)