3535 "rank_4_no_max" : lambda : (torch .rand (1 , 10 , 10 , 1 ) - 3 , - 3.3 , None ),
3636}
3737
38+ test_data_suite_int32 = {
39+ "int32_rank2" : lambda : (torch .randint (- 50 , 50 , (2 , 3 ), dtype = torch .int32 ), - 10 , 10 ),
40+ "int32_rank3_no_min" : lambda : (
41+ torch .randint (- 100 , 100 , (1 , 3 , 3 ), dtype = torch .int32 ),
42+ None ,
43+ 25 ,
44+ ),
45+ "int32_rank3_no_max" : lambda : (
46+ torch .randint (- 100 , 100 , (1 , 3 , 3 ), dtype = torch .int32 ),
47+ - 25 ,
48+ None ,
49+ ),
50+ "int32_rank4_large_range" : lambda : (
51+ torch .randint (- 200 , 200 , (1 , 2 , 4 , 4 ), dtype = torch .int32 ),
52+ torch .iinfo (torch .int32 ).min ,
53+ torch .iinfo (torch .int32 ).max ,
54+ ),
55+ }
56+
3857
3958class Clamp (torch .nn .Module ):
4059 def __init__ (
@@ -53,7 +72,6 @@ def forward(self, x):
5372
5473@common .parametrize ("test_data" , test_data_suite )
5574def test_clamp_tosa_FP (test_data ):
56-
5775 input_tensor , min_val , max_val = test_data ()
5876 model = Clamp (min_val , max_val )
5977
@@ -69,7 +87,6 @@ def test_clamp_tosa_FP(test_data):
6987
7088@common .parametrize ("test_data" , test_data_suite )
7189def test_clamp_tosa_INT (test_data ):
72-
7390 input_tensor , min_val , max_val = test_data ()
7491 model = Clamp (min_val , max_val )
7592
@@ -84,6 +101,22 @@ def test_clamp_tosa_INT(test_data):
84101 pipeline .run ()
85102
86103
104+ @common .parametrize ("test_data" , test_data_suite_int32 )
105+ def test_clamp_tosa_INT_int32_inputs (test_data ):
106+ input_tensor , min_val , max_val = test_data ()
107+ model = Clamp (min_val , max_val )
108+
109+ pipeline = TosaPipelineINT [input_t ](
110+ model ,
111+ (input_tensor ,),
112+ aten_op ,
113+ exir_op ,
114+ )
115+ pipeline .change_args ("run_method_and_compare_outputs" , qtol = 1 )
116+ pipeline .pop_stage ("quantize" )
117+ pipeline .run ()
118+
119+
87120@common .parametrize ("test_data" , test_data_suite )
88121def test_clamp_tosa_INT_a16w8 (test_data ):
89122 """Test clamp operation with int16 I/O quantization for TOSA INT."""
@@ -103,7 +136,6 @@ def test_clamp_tosa_INT_a16w8(test_data):
103136@common .parametrize ("test_data" , test_data_suite )
104137@common .XfailIfNoCorstone300
105138def test_clamp_u55_INT (test_data ):
106-
107139 input_tensor , min_val , max_val = test_data ()
108140 model = Clamp (min_val , max_val )
109141
@@ -140,7 +172,6 @@ def test_clamp_16a8w_u55_INT16(test_data):
140172@common .parametrize ("test_data" , test_data_suite )
141173@common .XfailIfNoCorstone320
142174def test_clamp_u85_INT (test_data ):
143-
144175 input_tensor , min_val , max_val = test_data ()
145176 model = Clamp (min_val , max_val )
146177
0 commit comments