@@ -49,6 +49,32 @@ void test_dtype() {
4949 EXPECT_TENSOR_EQ (out, expected);
5050}
5151
52+ template <ScalarType INPUT_DTYPE>
53+ void test_input_dtype () {
54+ TensorFactory<INPUT_DTYPE> tf_input;
55+
56+ Tensor input = tf_input.full ({3 , 5 }, 4 );
57+ double scale = 0.5 ;
58+ int64_t zero_point = 108 ;
59+ int64_t quant_min = 0 ;
60+ int64_t quant_max = 127 ;
61+
62+ TensorFactory<ScalarType::Char> tfo;
63+ Tensor out = tfo.zeros ({3 , 5 });
64+ // 4 / 0.5 + 108 = 116
65+ Tensor expected = tfo.full ({3 , 5 }, 116 );
66+ quantize_per_tensor_out (
67+ input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
68+
69+ EXPECT_TENSOR_EQ (out, expected);
70+ }
71+
72+ TEST (OpQuantizeOutTest, AllInputDtypesSupported) {
73+ test_input_dtype<ScalarType::Float>();
74+ test_input_dtype<ScalarType::Half>();
75+ test_input_dtype<ScalarType::Double>();
76+ }
77+
5278TEST (OpQuantizeOutTest, AllDtypesSupported) {
5379 test_dtype<ScalarType::Byte>();
5480 test_dtype<ScalarType::Char>();
@@ -58,6 +84,45 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) {
5884 test_dtype<ScalarType::Int>();
5985}
6086
87+ TEST (OpQuantizeOutTest, DoubleInputTest) {
88+ TensorFactory<ScalarType::Double> tf_double;
89+
90+ // Test with a more complex value that might have precision differences
91+ Tensor input = tf_double.full ({2 , 3 }, 3.14159265359 );
92+ double scale = 0.01 ;
93+ int64_t zero_point = -100 ;
94+ int64_t quant_min = 0 ;
95+ int64_t quant_max = 255 ;
96+
97+ TensorFactory<ScalarType::Byte> tfo;
98+ Tensor out = tfo.zeros ({2 , 3 });
99+ // 3.14159265359 / 0.01 - 100 = 214.159265359
100+ Tensor expected = tfo.full ({2 , 3 }, 214 );
101+ quantize_per_tensor_out (
102+ input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
103+
104+ EXPECT_TENSOR_EQ (out, expected);
105+ }
106+
107+ TEST (OpQuantizeOutTest, HalfInputTest) {
108+ TensorFactory<ScalarType::Half> tf_half;
109+
110+ Tensor input = tf_half.full ({2 , 3 }, 2.5 );
111+ double scale = 0.5 ;
112+ int64_t zero_point = 10 ;
113+ int64_t quant_min = -128 ;
114+ int64_t quant_max = 127 ;
115+
116+ TensorFactory<ScalarType::Char> tfo;
117+ Tensor out = tfo.zeros ({2 , 3 });
118+ // 2.5 / 0.5 + 10 = 15
119+ Tensor expected = tfo.full ({2 , 3 }, 15 );
120+ quantize_per_tensor_out (
121+ input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
122+
123+ EXPECT_TENSOR_EQ (out, expected);
124+ }
125+
61126TEST (OpQuantizeOutTest, TensorArgOverload) {
62127 TensorFactory<ScalarType::Float> tf_float;
63128 TensorFactory<ScalarType::Double> tf_double;
0 commit comments