11# RUN: %PYTHON %s | FileCheck %s
22
3+ import numpy as np
34from mlir .ir import *
45from mlir .dialects import quant
56
@@ -18,21 +19,27 @@ def test_type_hierarchy():
1819 any = Type .parse ("!quant.any<i8<-8:7>:f32>" )
1920 uniform = Type .parse ("!quant.uniform<i8<-8:7>:f32, 0.99872:127>" )
2021 per_axis = Type .parse ("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>" )
22+ sub_channel = Type .parse (
23+ "!quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>" )
2124 calibrated = Type .parse ("!quant.calibrated<f32<-0.998:1.2321>>" )
2225
2326 assert not quant .QuantizedType .isinstance (i8 )
2427 assert quant .QuantizedType .isinstance (any )
2528 assert quant .QuantizedType .isinstance (uniform )
2629 assert quant .QuantizedType .isinstance (per_axis )
30+ assert quant .QuantizedType .isinstance (sub_channel )
2731 assert quant .QuantizedType .isinstance (calibrated )
2832
2933 assert quant .AnyQuantizedType .isinstance (any )
3034 assert quant .UniformQuantizedType .isinstance (uniform )
3135 assert quant .UniformQuantizedPerAxisType .isinstance (per_axis )
36+ assert quant .UniformQuantizedSubChannelType .isinstance (sub_channel )
3237 assert quant .CalibratedQuantizedType .isinstance (calibrated )
3338
3439 assert not quant .AnyQuantizedType .isinstance (uniform )
3540 assert not quant .UniformQuantizedType .isinstance (per_axis )
41+ assert not quant .UniformQuantizedType .isinstance (sub_channel )
42+ assert not quant .UniformQuantizedPerAxisType .isinstance (sub_channel )
3643
3744
3845# CHECK-LABEL: TEST: test_any_quantized_type
@@ -121,6 +128,45 @@ def test_uniform_per_axis_type():
121128 assert per_axis == Type .parse ("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>" )
122129
123130
131+ # CHECK-LABEL: TEST: test_uniform_sub_channel_type
132+ @run
133+ def test_uniform_sub_channel_type ():
134+ with Context ():
135+ i8 = IntegerType .get_signless (8 )
136+ f32 = F32Type .get ()
137+ sub_channel = quant .UniformQuantizedSubChannelType .get (
138+ quant .QuantizedType .FLAG_SIGNED ,
139+ i8 ,
140+ f32 ,
141+ DenseElementsAttr .get (np .asarray (
142+ [2.0 , 3.0 , 4.0 , 5.0 ], np .float32 ).reshape (2 , 2 )),
143+ DenseElementsAttr .get (np .asarray (
144+ [10 , 20 , 30 , 40 ], np .int8 ).reshape (2 , 2 )),
145+ [0 , 1 ], [1 , 2 ],
146+ storage_type_min = quant .QuantizedType .default_minimum_for_integer (
147+ is_signed = True , integral_width = 8
148+ ),
149+ storage_type_max = quant .QuantizedType .default_maximum_for_integer (
150+ is_signed = True , integral_width = 8
151+ ),
152+ )
153+
154+ # CHECK: quantized dimensions: [0, 1]
155+ print (f"quantized dimensions: { sub_channel .quantized_dimensions } " )
156+ # CHECK: block sizes: [1, 2]
157+ print (f"block sizes: { sub_channel .block_sizes } " )
158+ # CHECK: scales: {{\[}}[2. 3.]
159+ # CHECK: [4. 5.]]
160+ print (f"scales: { np .asarray (sub_channel .scales )} " )
161+ # CHECK: zero-points: {{\[}}[10 20]
162+ # CHECK: [30 40]]
163+ print (f"zero-points: { np .asarray (sub_channel .zero_points )} " )
164+ # CHECK: !quant.uniform<i8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:10, 3.000000e+00:20}, {4.000000e+00:30, 5.000000e+00:40}}>
165+ print (sub_channel )
166+ assert sub_channel == Type .parse (
167+ "!quant.uniform<i8:f32:{0:1,1:2},{{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>" )
168+
169+
124170# CHECK-LABEL: TEST: test_calibrated_type
125171@run
126172def test_calibrated_type ():
0 commit comments