1- # Copyright 2024 Arm Limited and/or its affiliates.
2- # All rights reserved.
1+ # Copyright 2024-2025 Arm Limited and/or its affiliates.
32#
43# This source code is licensed under the BSD-style license found in the
54# LICENSE file in the root directory of this source tree.
65import itertools
7- import unittest
6+
7+ from typing import Tuple
88
99import torch
1010from executorch .backends .arm .quantizer import is_annotated
11- from executorch .backends .arm .test import common
12- from executorch . backends . arm . test . tester . arm_tester import ArmTester
11+ from executorch .backends .arm .test . tester . test_pipeline import TosaPipelineBI
12+
1313from torch .fx .passes .utils .source_matcher_utils import get_source_partitions
1414
1515
16+ input_t1 = Tuple [torch .Tensor ] # Input x
17+
18+
1619class SingleOpModel (torch .nn .Module ):
1720 def __init__ (self , op , example_input , ** op_kwargs ) -> None :
1821 super ().__init__ ()
@@ -27,69 +30,74 @@ def example_inputs(self):
2730 return self ._example_input
2831
2932
30- class TestGenericAnnotator (unittest .TestCase ):
31- def check_annotation (self , model ):
32- tester = ArmTester (
33- model ,
34- model .example_inputs (),
35- common .get_tosa_compile_spec ("TOSA-0.80+BI" ),
36- )
37- quant_model = tester .quantize ().get_artifact ()
38- partitions = get_source_partitions (quant_model .graph , [model .op ])
39- partitions = list (itertools .chain .from_iterable (partitions .values ()))
40-
41- assert len (partitions ) == 1
42- partition = partitions [0 ]
43- assert all (is_annotated (node ) for node in partition .nodes )
44-
45- def test_squeeze (self ):
46- self .check_annotation (SingleOpModel (torch .squeeze , (torch .rand (8 , 8 , 1 ),)))
47- self .check_annotation (SingleOpModel (torch .squeeze_copy , (torch .rand (8 , 8 , 1 ),)))
48-
49- def test_unsqueeze (self ):
50- self .check_annotation (
51- SingleOpModel (torch .unsqueeze , (torch .rand (8 , 8 ),), dim = 0 )
52- )
53- self .check_annotation (
54- SingleOpModel (torch .unsqueeze_copy , (torch .rand (8 , 8 ),), dim = 0 )
55- )
56-
57- def test_reshape (self ):
58- self .check_annotation (
59- SingleOpModel (torch .reshape , (torch .randn (8 , 8 ),), shape = (64 ,)),
60- )
61-
62- def test_view (self ):
63- self .check_annotation (
64- SingleOpModel (torch .view_copy , (torch .randn (4 , 4 ),), size = (2 , 8 )),
65- )
66-
67- def test_slice (self ):
68- self .check_annotation (
69- SingleOpModel (torch .slice_copy , (torch .randn (3 , 4 ),)),
70- )
71-
72- def test_transpose (self ):
73- self .check_annotation (
74- SingleOpModel (torch .transpose , (torch .randn (2 , 3 ),), dim0 = 0 , dim1 = 1 ),
75- )
76- self .check_annotation (
77- SingleOpModel (torch .transpose_copy , (torch .randn (2 , 3 ),), dim0 = 0 , dim1 = 1 ),
78- )
79-
80- def test_tile (self ):
81- self .check_annotation (
82- SingleOpModel (torch .tile , (torch .randn (4 , 4 ),), dims = (2 ,)),
83- )
84-
85- def test_flip (self ):
86- self .check_annotation (
87- SingleOpModel (torch .flip , (torch .randn (2 , 4 ),), dims = (0 , 1 )),
88- )
89-
90- def test_concat (self ):
91- self .check_annotation (
92- SingleOpModel (
93- torch .concatenate , ((torch .randn (2 , 3 ), torch .randn (2 , 3 )),), dim = 0
94- ),
95- )
33+ def check_annotation (model ):
34+ pipeline = TosaPipelineBI [input_t1 ](model , model .example_inputs (), [], [])
35+ pipeline .pop_stage ("check_count.exir" )
36+ pipeline .pop_stage ("run_method_and_compare_outputs" )
37+ pipeline .run ()
38+
39+ artifact = pipeline .tester .get_artifact ("Quantize" )
40+
41+ partitions = get_source_partitions (artifact .graph , [model .op ])
42+ partitions = list (itertools .chain .from_iterable (partitions .values ()))
43+
44+ assert len (partitions ) == 1
45+ partition = partitions [0 ]
46+ assert all (is_annotated (node ) for node in partition .nodes )
47+
48+
49+ def test_squeeze ():
50+ check_annotation (SingleOpModel (torch .squeeze , (torch .rand (8 , 8 , 1 ),)))
51+ check_annotation (SingleOpModel (torch .squeeze_copy , (torch .rand (8 , 8 , 1 ),)))
52+
53+
54+ def test_unsqueeze ():
55+ check_annotation (SingleOpModel (torch .unsqueeze , (torch .rand (8 , 8 ),), dim = 0 ))
56+ check_annotation (SingleOpModel (torch .unsqueeze_copy , (torch .rand (8 , 8 ),), dim = 0 ))
57+
58+
59+ def test_reshape ():
60+ check_annotation (
61+ SingleOpModel (torch .reshape , (torch .randn (8 , 8 ),), shape = (64 ,)),
62+ )
63+
64+
65+ def test_view ():
66+ check_annotation (
67+ SingleOpModel (torch .view_copy , (torch .randn (4 , 4 ),), size = (2 , 8 )),
68+ )
69+
70+
71+ def test_slice ():
72+ check_annotation (
73+ SingleOpModel (torch .slice_copy , (torch .randn (3 , 4 ),)),
74+ )
75+
76+
77+ def test_transpose ():
78+ check_annotation (
79+ SingleOpModel (torch .transpose , (torch .randn (2 , 3 ),), dim0 = 0 , dim1 = 1 ),
80+ )
81+ check_annotation (
82+ SingleOpModel (torch .transpose_copy , (torch .randn (2 , 3 ),), dim0 = 0 , dim1 = 1 ),
83+ )
84+
85+
86+ def test_tile ():
87+ check_annotation (
88+ SingleOpModel (torch .tile , (torch .randn (4 , 4 ),), dims = (2 ,)),
89+ )
90+
91+
92+ def test_flip ():
93+ check_annotation (
94+ SingleOpModel (torch .flip , (torch .randn (2 , 4 ),), dims = (0 , 1 )),
95+ )
96+
97+
98+ def test_concat ():
99+ check_annotation (
100+ SingleOpModel (
101+ torch .concatenate , ((torch .randn (2 , 3 ), torch .randn (2 , 3 )),), dim = 0
102+ ),
103+ )
0 commit comments