11# RUN: %PYTHON %s | FileCheck %s
22
3+ import numpy as np
34from mlir .ir import *
45from mlir .dialects import quant
56
@@ -18,21 +19,28 @@ def test_type_hierarchy():
1819 any = Type .parse ("!quant.any<i8<-8:7>:f32>" )
1920 uniform = Type .parse ("!quant.uniform<i8<-8:7>:f32, 0.99872:127>" )
2021 per_axis = Type .parse ("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>" )
22+ sub_channel = Type .parse (
23+ "!quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>"
24+ )
2125 calibrated = Type .parse ("!quant.calibrated<f32<-0.998:1.2321>>" )
2226
2327 assert not quant .QuantizedType .isinstance (i8 )
2428 assert quant .QuantizedType .isinstance (any )
2529 assert quant .QuantizedType .isinstance (uniform )
2630 assert quant .QuantizedType .isinstance (per_axis )
31+ assert quant .QuantizedType .isinstance (sub_channel )
2732 assert quant .QuantizedType .isinstance (calibrated )
2833
2934 assert quant .AnyQuantizedType .isinstance (any )
3035 assert quant .UniformQuantizedType .isinstance (uniform )
3136 assert quant .UniformQuantizedPerAxisType .isinstance (per_axis )
37+ assert quant .UniformQuantizedSubChannelType .isinstance (sub_channel )
3238 assert quant .CalibratedQuantizedType .isinstance (calibrated )
3339
3440 assert not quant .AnyQuantizedType .isinstance (uniform )
3541 assert not quant .UniformQuantizedType .isinstance (per_axis )
42+ assert not quant .UniformQuantizedType .isinstance (sub_channel )
43+ assert not quant .UniformQuantizedPerAxisType .isinstance (sub_channel )
3644
3745
3846# CHECK-LABEL: TEST: test_any_quantized_type
@@ -121,6 +129,47 @@ def test_uniform_per_axis_type():
121129 assert per_axis == Type .parse ("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>" )
122130
123131
# CHECK-LABEL: TEST: test_uniform_sub_channel_type
@run
def test_uniform_sub_channel_type():
    """Build a sub-channel uniform quantized type and print its properties.

    The type is created through ``UniformQuantizedSubChannelType.get`` with a
    signed 8-bit signless integer storage type, an f32 expressed type, 2x2
    scale and zero-point tensors, quantized dimensions ``[0, 1]`` and block
    sizes ``[1, 2]``. Each accessor is printed so the surrounding FileCheck
    directives can match the output, and the result is finally compared with
    the same type obtained from its textual form.
    """
    with Context():
        storage_type = IntegerType.get_signless(8)
        expressed_type = F32Type.get()

        # One (scale, zero-point) pair per block, laid out as a 2x2 grid.
        scales_attr = DenseElementsAttr.get(
            np.asarray([[2.0, 3.0], [4.0, 5.0]], dtype=np.float32)
        )
        zero_points_attr = DenseElementsAttr.get(
            np.asarray([[10, 20], [30, 40]], dtype=np.int8)
        )

        # Storage bounds are the defaults for a signed 8-bit integer.
        storage_min = quant.QuantizedType.default_minimum_for_integer(
            is_signed=True, integral_width=8
        )
        storage_max = quant.QuantizedType.default_maximum_for_integer(
            is_signed=True, integral_width=8
        )

        sub_channel = quant.UniformQuantizedSubChannelType.get(
            quant.QuantizedType.FLAG_SIGNED,
            storage_type,
            expressed_type,
            scales_attr,
            zero_points_attr,
            [0, 1],
            [1, 2],
            storage_type_min=storage_min,
            storage_type_max=storage_max,
        )

        # CHECK: quantized dimensions: [0, 1]
        print(f"quantized dimensions: {sub_channel.quantized_dimensions} ")
        # CHECK: block sizes: [1, 2]
        print(f"block sizes: {sub_channel.block_sizes} ")
        # CHECK: scales: {{\[}}[2. 3.]
        # CHECK: [4. 5.]]
        print(f"scales: {np.asarray(sub_channel.scales)} ")
        # CHECK: zero-points: {{\[}}[10 20]
        # CHECK: [30 40]]
        print(f"zero-points: {np.asarray(sub_channel.zero_points)} ")
        # CHECK: !quant.uniform<i8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:10, 3.000000e+00:20}, {4.000000e+00:30, 5.000000e+00:40}}>
        print(sub_channel)

        # The programmatically built type must equal the parsed one.
        parsed = Type.parse(
            "!quant.uniform<i8:f32:{0:1,1:2},{{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>"
        )
        assert sub_channel == parsed
171+
172+
124173# CHECK-LABEL: TEST: test_calibrated_type
125174@run
126175def test_calibrated_type ():
0 commit comments