@@ -40,13 +40,17 @@ def Quant_Dialect : Dialect {
4040 encodes the necessary information for (lossy) round-trip conversion between
4141 an expressed and a stored value.
4242
43- The `quant.uniform` type has two variants: per-layer quantization and
44- per-channel (or per-axis) quantization. In per-layer quantization, the
45- quantization information affects an entire tensor uniformly. Conversely, in
46- per-channel quantization, the data type encodes the specific tensor axis
47- that serves as the channel and includes quantization information for each
48- individual channel within the tensor. Below are the specific syntactic and
49- semantic considerations for each modality.
43+ The `quant.uniform` type has three variants: per-layer quantization,
44+ per-channel (or per-axis) quantization, and sub-channel (or blockwise)
45+ quantization. In per-layer quantization, the quantization information
46+ affects an entire tensor uniformly. Conversely, in per-channel
47+ quantization, the data type encodes the specific tensor axis that serves
48+ as the channel and includes quantization information for each individual
49+ channel within the tensor. Sub-channel quantization is a generalization
50+ of per-tensor and per-channel quantization, where the quantization
51+ parameters are defined for blocks of elements along one or more
52+ dimensions of the tensor. Below are the specific syntactic and semantic
53+ considerations for each modality.
5054
5155
5256 ### Per-layer quantization
@@ -145,7 +149,7 @@ def Quant_Dialect : Dialect {
145149 ```
146150 // A 2x3x4 tensor contains 8-bit signed integers representing 32-bit
147151 // floats. Dimension 1 of the tensor acts as the channel dimension. Its
148- // size 3 matches the number of provided scale values. Tensor elemenets at
152+ // size 3 matches the number of provided scale values. Tensor elements at
149153 // positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and
150154 // 5.0, respectively.
151155 tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>>
@@ -159,6 +163,72 @@ def Quant_Dialect : Dialect {
159163 tensor<?x?x!quant.uniform<u16:f32:0, {2.0:10, 3.0:20}>>
160164 ```
161165
166+ ### Sub-channel quantization
167+
168+ Sub-channel quantization, also known as blockwise quantization, provides
169+ finer-grained control than per-tensor or per-channel quantization. It
170+ divides a tensor into blocks of elements, each with its own quantization
171+ parameters (scale and zero point). This is particularly useful when
172+ different regions of a tensor exhibit distinct value ranges.
173+
174+ The `!quant.uniform` type represents sub-channel quantization with the
175+ following syntax:
176+
177+ ```
178+ `!quant.uniform` `<`
179+ storedType (`<` storageMin `:` storageMax `>`)? `:`
180+ expressedType `:` blockSizeInfo
181+ scaleZeroTensor `>`
182+
183+ blockSizeInfo ::= `{` `}` | `{` axisBlock (`,` axisBlock)* `}`
184+ axisBlock ::= axis `:` blockSize
185+ scaleZeroTensor ::= scaleZeroDenseExp | scaleZeroList
186+ scaleZeroDenseExp ::= `{` scaleZeroTensor (`,` scaleZeroTensor)* `}`
187+ scaleZeroList ::= scaleZero (`,` scaleZero)*
188+ scaleZero ::= scale (`:` zeroPoint)?
193+ ```
194+
195+ The `blockSize` field specifies the size of the blocks along dimension
196+ `axis` of the tensor. The `scale` and `zeroPoint` fields specify the
197+ quantization parameters for a particular block. Specifically, the tensor
198+ element at position [i0, ..., iN] uses
199+ `scaleZeroTensor[i0/blockSize0, ..., iN/blockSizeN].scale` and
200+ `scaleZeroTensor[i0/blockSize0, ..., iN/blockSizeN].zeroPoint` as its
201+ scale and zero point, respectively.
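This indexing rule can be sketched as follows (hypothetical Python helper, not part of the dialect; the table layout is illustrative):

```python
# Hypothetical helper illustrating the lookup rule above: each element
# index is divided elementwise by the per-axis block sizes to locate the
# entry in the nested scale/zero-point table.
def params_for(index, block_sizes, scale_zero_tensor):
    entry = scale_zero_tensor
    for i, b in zip(index, block_sizes):
        entry = entry[i // b]
    return entry  # (scale, zeroPoint) pair for that element's block

# A 3x4 tensor with block sizes [1, 2]: the table has shape [3, 2].
table = [[(3.0, 0), (4.0, 1)],
         [(5.0, 2), (6.0, 3)],
         [(7.0, 4), (8.0, 5)]]
```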
202+
203+ Here are some examples:
204+
205+ ```
206+ // A 3x4 tensor of i8 values representing f32 values, quantized
207+ // along axis-0 and axis-1 with block sizes 1 and 2,
208+ // respectively. As a result, the shape of the scales (or zero-points) will
209+ // be `[3,4]/[1,2] = [3,2]`, which essentially represents the number of
210+ // blocks along each axis. Tensor elements at positions
211+ // [0][0] and [0][1] use scale `s00` and zero point `z00`,
212+ // [0][2] and [0][3] use scale `s01` and zero point `z01`,
213+ // [1][0] and [1][1] use scale `s10` and zero point `z10`,
214+ // [1][2] and [1][3] use scale `s11` and zero point `z11`,
215+ // [2][0] and [2][1] use scale `s20` and zero point `z20`,
216+ // [2][2] and [2][3] use scale `s21` and zero point `z21`.
217+ tensor<3x4x!quant.uniform<i8:f32:{0:1, 1:2},
218+ {{s00:z00, s01:z01}, {s10:z10,s11:z11}, {s20:z20,s21:z21}}>>
219+
220+ // A 2D dynamically sized tensor contains u16 values
221+ // representing f32 values. Since the shape of the quantization
222+ // parameters (i.e. scales and zero-points) is given as [2,2] and
223+ // the blocks-sizes are given as [1,2], the shape of the tensor is expected
224+ // to be [2,4] (= [2,2] * [1,2]) at runtime. Tensor elements at positions
225+ // [0][0] and [0][1] use scale `s00` and zero point `z00`,
226+ // [0][2] and [0][3] use scale `s01` and zero point `z01`,
227+ // [1][0] and [1][1] use scale `s10` and zero point `z10`,
228+ // [1][2] and [1][3] use scale `s11` and zero point `z11`.
229+ tensor<?x?x!quant.uniform<u16:f32:{0:1, 1:2},
230+ {{s00:z00, s01:z01}, {s10:z10,s11:z11}}>>
231+ ```
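The scale shapes quoted in the comments above follow from an elementwise division, which a minimal sketch can make concrete (hypothetical helper, not dialect code):

```python
# Hypothetical helper: the shape of the scales (and zero-points) is the
# tensor shape divided elementwise by the block sizes, which must divide
# the corresponding dimension sizes evenly.
def scales_shape(tensor_shape, block_sizes):
    assert all(d % b == 0 for d, b in zip(tensor_shape, block_sizes))
    return [d // b for d, b in zip(tensor_shape, block_sizes)]
```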
162232
163233 ## Per-axis quantization integrity
164234
@@ -170,7 +240,7 @@ def Quant_Dialect : Dialect {
170240 respected in any context in which the `!quant.uniform` data type is used,
171241 such as the header of a `func.func` op, or the input of an arithmetic
172242 operation.
173-
243+
174244 - A quantized type with per-channel quantization information must be the
175245 element type of a tensor container type, and may not occur directly as
176246 the data type of a scalar value.
@@ -209,6 +279,110 @@ def Quant_Dialect : Dialect {
209279 // Correct. The quantized type now includes 3 scale values, matching the
210280 // size of dimension 1 of the result tensor.
211281 %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
282+
283+ ## Sub-channel quantization integrity
284+
285+ When type `!quant.uniform` contains sub-channel quantization information,
286+ the following rules are enforced. For efficiency, these rules are actively
287+ enforced by the verifiers of `quant` dialect ops, but they must be
288+ respected in any context in which the `!quant.uniform` data type is used,
289+ such as the header of a `func.func` op, or the input of an arithmetic
290+ operation.
291+
292+ - A quantized type with sub-channel quantization information must be the
293+ element type of a tensor container type, and may not occur directly as
294+ the data type of a scalar value.
295+
296+ ```
297+ // Incorrect. Type !quant.uniform specifies sub-channel quantization for a
298+ // scalar type.
299+ %result = quant.qcast %input : f32 to !quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>
300+
301+ // Correct. Type `!quant.uniform` with sub-channel quantization is wrapped
302+ // in a `tensor` type.
303+ %result = quant.qcast %input : tensor<2x2xf32> to
304+ tensor<2x2x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
305+ ```
306+
307+ - The tensor containing the sub-channel quantized type must be ranked.
308+
309+ ```
310+ // Incorrect. Type !quant.uniform specifies sub-channel quantization for
311+ // an unranked tensor type.
312+ %result = quant.qcast %input : tensor<*xf32> to
313+ tensor<*x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
314+ ```
315+
316+ - Each axis for which a block size is specified must be a valid axis for
317+ the rank of the tensor. Block sizes may be specified for only a subset
318+ of axes; an unspecified block size for axis i defaults to the dimension
319+ size of that axis (shape(tensor)[i]).
320+
321+ ```
322+ // Incorrect. A block size is specified for axis 2, which is out of range
323+ // for a tensor of rank 2 (valid axes are 0 and 1).
324+ %result = quant.qcast %input : tensor<2x2xf32> to
325+ tensor<2x2x!quant.uniform<i8:f32:{2:1, 1:2}, {{1.0}, {2.0}}>>
326+
327+ // Incorrect. The block-size is specified for a negative axis.
328+ %result = quant.qcast %input : tensor<2x2xf32> to
329+ tensor<2x2x!quant.uniform<i8:f32:{-1:1, 1:2}, {{1.0}, {2.0}}>>
330+
331+ // Correct. The block size for axis 1 is omitted, so it defaults to 2,
332+ // the dimension size of the tensor at axis 1.
333+ %result = quant.qcast %input : tensor<6x2xf32> to
334+ tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {3.0}}>>
335+
336+ // Correct. The block sizes for all axes are omitted, making the
337+ // sub-channel type effectively a per-tensor type.
338+ %result = quant.qcast %input : tensor<6x2xf32> to
339+ tensor<6x2x!quant.uniform<i8:f32:{}, {{1.0}}>>
340+ ```
341+
342+ - The block size for a particular axis must be a positive integer and
343+ must not exceed the dimension size of the tensor along that axis.
344+
345+ ```
346+ // Incorrect. The block size for axis 0 is -1.
347+ %result = quant.qcast %input : tensor<6x2xf32> to
348+ tensor<6x2x!quant.uniform<i8:f32:{0:-1}, {{1.0, 2.0}}>>
349+
350+ // Incorrect. The block size for axis 0 is 8 which is greater than the
351+ // dimension size of tensor at axis 0 (which is 6).
352+ %result = quant.qcast %input : tensor<6x2xf32> to
353+ tensor<6x2x!quant.uniform<i8:f32:{0:8}, {{1.0, 2.0}}>>
354+
355+ // Correct. The block size for axis 0 is now 3.
356+ %result = quant.qcast %input : tensor<6x2xf32> to
357+ tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
358+ ```
359+
360+ - shape(tensor) % blockSizes = 0, where blockSizes[i] is the block size
361+ for axis i, for each i in [0, 1, ..., rank(tensor)-1].
362+
363+ ```
364+ // Incorrect. The block size for axis 0 is 4, but the corresponding
365+ // dimension size is 6, and 6 % 4 != 0.
366+ %result = quant.qcast %input : tensor<6x2xf32> to
367+ tensor<6x2x!quant.uniform<i8:f32:{0:4}, {{1.0, 2.0}}>>
368+
369+ // Correct. The block size for axis 0 is now 3 making 6 % 3 = 0.
370+ %result = quant.qcast %input : tensor<6x2xf32> to
371+ tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
372+ ```
373+
374+ - shape(scales) = shape(zeroPoints) = shape(tensor) / blockSizes.
375+
376+ ```
377+ // Incorrect. shape(tensor) = [6,2], blockSizes = [3,2], but
378+ // shape(scales) is [1,2] which is not equal to [6,2]/[3,2].
379+ %result = quant.qcast %input : tensor<6x2xf32> to
380+ tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0, 2.0}}>>
381+
382+ // Correct. shape(tensor) = [6,2], blockSizes = [3,2], and
383+ // shape(scales) equals [6,2]/[3,2].
384+ %result = quant.qcast %input : tensor<6x2xf32> to
385+ tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
212386 ```
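Taken together, the sub-channel rules above amount to a shape check along the following lines (a hypothetical Python sketch, loosely mirroring what a verifier would do; not the dialect's actual implementation):

```python
# Hypothetical validity check for sub-channel quantization parameters:
# valid axes, positive in-bounds block sizes, even divisibility, and a
# matching scales/zero-points shape.
def check_sub_channel(tensor_shape, block_size_info, scales_shape):
    rank = len(tensor_shape)
    block_sizes = list(tensor_shape)  # unspecified axes default to the dim size
    for axis, size in block_size_info.items():
        if not 0 <= axis < rank:
            return False  # axis must be valid for the tensor's rank
        if size <= 0 or size > tensor_shape[axis]:
            return False  # block size must be positive and within the dim
        block_sizes[axis] = size
    if any(d % b != 0 for d, b in zip(tensor_shape, block_sizes)):
        return False  # shape(tensor) % blockSizes must be 0
    return scales_shape == [d // b for d, b in zip(tensor_shape, block_sizes)]
```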
213387 }];
214388 let cppNamespace = "::mlir::quant";