33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6- import unittest
6+
7+ from typing import Tuple
78
89import torch
910from executorch .backends .arm ._passes .convert_to_clamp import ConvertToClampPass
1011
1112from executorch .backends .arm .test import common
12- from executorch .backends .arm .test .tester .arm_tester import ArmTester
13+ from executorch .backends .arm .test .tester .test_pipeline import PassPipeline
1314
14- from executorch . backends . xnnpack . test . tester . tester import RunPasses
15+ input_t = Tuple [ torch . Tensor ] # Input x
1516
1617
1718class HardTanh (torch .nn .Module ):
19+ test_data = {"rand" : (torch .rand (1 , 64 , 64 , 3 ),)}
20+
1821 def __init__ (self ):
1922 super ().__init__ ()
2023
@@ -23,11 +26,10 @@ def __init__(self):
2326 def forward (self , x ):
2427 return self .hardtanh (x )
2528
26- def get_inputs (self ):
27- return (torch .rand (1 , 64 , 64 , 3 ),)
28-
2929
3030class ReLU (torch .nn .Module ):
31+ test_data = {"rand" : (torch .rand (1 , 64 , 64 , 3 ),)}
32+
3133 def __init__ (self ):
3234 super ().__init__ ()
3335
@@ -36,45 +38,55 @@ def __init__(self):
3638 def forward (self , x ):
3739 return self .relu (x )
3840
39- def get_inputs (self ):
40- return (torch .rand (1 , 64 , 64 , 3 ),)
41-
42-
43- class TestConvertToClampPass (unittest .TestCase ):
44- """
45- Tests the ConvertToClampPass which converts hardtanh.default and relu.default to clamp.default
46- """
47-
48- def test_tosa_MI_hardtahn (self ):
49- module = HardTanh ()
50- test_pass_stage = RunPasses ([ConvertToClampPass ])
51- (
52- ArmTester (
53- module ,
54- example_inputs = module .get_inputs (),
55- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" ),
56- )
57- .export ()
58- .to_edge ()
59- .check (["executorch_exir_dialects_edge__ops_aten_hardtanh_default" ])
60- .run_passes (test_pass_stage )
61- .check (["executorch_exir_dialects_edge__ops_aten_clamp_default" ])
62- .check_not (["executorch_exir_dialects_edge__ops_aten_hardtanh_default" ])
63- )
64-
65- def test_tosa_MI_relu (self ):
66- module = ReLU ()
67- test_pass_stage = RunPasses ([ConvertToClampPass ])
68- (
69- ArmTester (
70- module ,
71- example_inputs = module .get_inputs (),
72- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" ),
73- )
74- .export ()
75- .to_edge ()
76- .check (["executorch_exir_dialects_edge__ops_aten_relu_default" ])
77- .run_passes (test_pass_stage )
78- .check (["executorch_exir_dialects_edge__ops_aten_clamp_default" ])
79- .check_not (["executorch_exir_dialects_edge__ops_aten_relu_default" ])
80- )
41+
42+ """
43+ Tests the ConvertToClampPass which converts hardtanh.default and relu.default to clamp.default
44+ """
45+
46+
47+ @common .parametrize ("test_data" , HardTanh .test_data )
48+ def test_tosa_MI_hardtahn (test_data : input_t ):
49+ module = HardTanh ()
50+ op_checks_before_pass = {
51+ "executorch_exir_dialects_edge__ops_aten_hardtanh_default" : 1 ,
52+ }
53+ op_checks_after_pass = {
54+ "executorch_exir_dialects_edge__ops_aten_clamp_default" : 1 ,
55+ }
56+ op_checks_not_after_pass = [
57+ "executorch_exir_dialects_edge__ops_aten_hardtanh_default" ,
58+ ]
59+ pipeline = PassPipeline [input_t ](
60+ module ,
61+ test_data ,
62+ quantize = False ,
63+ ops_before_pass = op_checks_before_pass ,
64+ ops_after_pass = op_checks_after_pass ,
65+ ops_not_after_pass = op_checks_not_after_pass ,
66+ pass_list = [ConvertToClampPass ],
67+ )
68+ pipeline .run ()
69+
70+
71+ @common .parametrize ("test_data" , ReLU .test_data )
72+ def test_tosa_MI_relu (test_data : input_t ):
73+ module = ReLU ()
74+ op_checks_before_pass = {
75+ "executorch_exir_dialects_edge__ops_aten_relu_default" : 1 ,
76+ }
77+ op_checks_after_pass = {
78+ "executorch_exir_dialects_edge__ops_aten_clamp_default" : 1 ,
79+ }
80+ op_checks_not_after_pass = [
81+ "executorch_exir_dialects_edge__ops_aten_relu_default" ,
82+ ]
83+ pipeline = PassPipeline [input_t ](
84+ module ,
85+ test_data ,
86+ quantize = False ,
87+ ops_before_pass = op_checks_before_pass ,
88+ ops_after_pass = op_checks_after_pass ,
89+ ops_not_after_pass = op_checks_not_after_pass ,
90+ pass_list = [ConvertToClampPass ],
91+ )
92+ pipeline .run ()
0 commit comments