Skip to content

Commit 057c6be

Browse files
tom-armoscarandersson8218
authored andcommitted
Arm backend: Add VIT test
Change-Id: Icb1923a9b2cc4740ceff0d4daccba1d12ead6e5f Signed-off-by: Oscar Andersson <[email protected]> Co-authored-by: Tom Allsop <[email protected]>
1 parent b809abc commit 057c6be

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import pytest
9+
10+
import torch
11+
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineBI,
14+
EthosU85PipelineBI,
15+
TosaPipelineBI,
16+
TosaPipelineMI,
17+
)
18+
19+
from torchvision import models, transforms
20+
21+
vit_b_16_model = models.vit_b_16(weights="IMAGENET1K_V1")
22+
vit = vit_b_16_model.eval()
23+
24+
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
25+
26+
input_tensor = torch.rand(1, 3, 224, 224)
27+
28+
model_inputs = (normalize(input_tensor),)
29+
input_t = Tuple[torch.Tensor]
30+
31+
32+
@pytest.mark.slow
33+
def test_vit_tosa_MI():
34+
pipeline = TosaPipelineMI[input_t](
35+
vit, model_inputs, aten_op=[], exir_op=[], use_to_edge_transform_and_lower=True
36+
)
37+
pipeline.run()
38+
39+
40+
@pytest.mark.slow
41+
def test_vit_tosa_BI():
42+
pipeline = TosaPipelineBI[input_t](
43+
vit,
44+
model_inputs,
45+
aten_op=[],
46+
exir_op=[],
47+
atol=5.0,
48+
qtol=1,
49+
)
50+
51+
pipeline.run()
52+
53+
54+
@pytest.mark.slow
55+
@pytest.mark.xfail(reason="Unsupported transpose")
56+
def test_vit_u55_BI():
57+
pipeline = EthosU55PipelineBI[input_t](
58+
vit,
59+
model_inputs,
60+
aten_ops=[],
61+
exir_ops=[],
62+
run_on_fvp=False,
63+
)
64+
pipeline.run()
65+
66+
67+
@pytest.mark.slow
68+
def test_vit_u85_BI():
69+
pipeline = EthosU85PipelineBI[input_t](
70+
vit,
71+
model_inputs,
72+
aten_ops=[],
73+
exir_ops=[],
74+
run_on_fvp=False,
75+
)
76+
pipeline.run()

0 commit comments

Comments
 (0)