@@ -38,11 +38,11 @@ Type getScalarType(Type inputType) {
3838 return inputType;
3939}
4040
41- // Return the shape of an input value as a list of attributes (static dimensions)
42- // and values (dynamic dimensions). If 'input' is a scalar, an empty list is
43- // returned. If 'input' is a tensor, its shape is returned.
44- SmallVector<OpFoldResult>
45- getScalarOrTensorShape (OpBuilder &builder, Location loc, Value input) {
41+ // Return the shape of an input value as a list of attributes (static
42+ // dimensions) and values (dynamic dimensions). If 'input' is a scalar, an empty
43+ // list is returned. If 'input' is a tensor, its shape is returned.
44+ SmallVector<OpFoldResult> getScalarOrTensorShape (OpBuilder &builder,
45+ Location loc, Value input) {
4646 if (isa<TensorType>(input.getType ()))
4747 return tensor::getMixedSizes (builder, loc, input);
4848 return {};
@@ -100,16 +100,16 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
100100
101101 // Turn input size into 1D tensor
102102 auto flatShapeType = shape::getExtentTensorType (context, 1 );
103- auto flatInputShape = builder. create <tensor::FromElementsOp>(
104- loc, flatShapeType, inputSize);
103+ auto flatInputShape =
104+ builder. create <tensor::FromElementsOp>( loc, flatShapeType, inputSize);
105105
106106 // Reshape input tensor into 1D
107107 auto inputType = cast<UnrankedTensorType>(input.getType ());
108108 auto elementType = inputType.getElementType ();
109109 auto flatInputType =
110110 RankedTensorType::get ({ShapedType::kDynamic }, elementType);
111- auto flatInput = builder.create <tensor::ReshapeOp>(
112- loc, flatInputType, input, flatInputShape);
111+ auto flatInput = builder.create <tensor::ReshapeOp>(loc, flatInputType, input,
112+ flatInputShape);
113113 return std::make_pair (flatInput, inputShape);
114114}
115115
@@ -135,11 +135,9 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
135135// - inputShape
136136// 1D extent tensor containing the shape of the original unranked input.
137137//
138- std::pair<Value, Value> flattenUnrankedTensorAroundAxis (OpBuilder &builder,
139- Location loc,
140- Value input,
141- int64_t axis,
142- int64_t axisSize) {
138+ std::pair<Value, Value>
139+ flattenUnrankedTensorAroundAxis (OpBuilder &builder, Location loc, Value input,
140+ int64_t axis, int64_t axisSize) {
143141 // Get full tensor shape
144142 auto *context = builder.getContext ();
145143 auto indexType = builder.getIndexType ();
@@ -149,16 +147,20 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
149147 // Get shape and sizes on left and right of axis
150148 auto axisValue = builder.create <arith::ConstantIndexOp>(loc, axis);
151149 auto axisNextValue = builder.create <arith::ConstantIndexOp>(loc, axis + 1 );
152- auto shapeLeft = builder.create <shape::SplitAtOp>(
153- loc, TypeRange{shapeType, shapeType}, inputShape, axisValue)
154- .getResult (0 );
155- auto sizeLeft = builder.create <shape::NumElementsOp>(
156- loc, indexType, shapeLeft);
157- auto shapeRight = builder.create <shape::SplitAtOp>(
158- loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue)
159- .getResult (1 );
160- auto sizeRight = builder.create <shape::NumElementsOp>(
161- loc, indexType, shapeRight);
150+ auto shapeLeft =
151+ builder
152+ .create <shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
153+ inputShape, axisValue)
154+ .getResult (0 );
155+ auto sizeLeft =
156+ builder.create <shape::NumElementsOp>(loc, indexType, shapeLeft);
157+ auto shapeRight =
158+ builder
159+ .create <shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
160+ inputShape, axisNextValue)
161+ .getResult (1 );
162+ auto sizeRight =
163+ builder.create <shape::NumElementsOp>(loc, indexType, shapeRight);
162164
163165 // Compute flat input shape as a 3-element 1D tensor
164166 auto axisSizeValue = builder.create <arith::ConstantIndexOp>(loc, axisSize);
@@ -171,8 +173,8 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
171173 auto elementType = inputType.getElementType ();
172174 auto flatInputType = RankedTensorType::get (
173175 {ShapedType::kDynamic , axisSize, ShapedType::kDynamic }, elementType);
174- auto flatInput = builder.create <tensor::ReshapeOp>(
175- loc, flatInputType, input, flatInputShape);
176+ auto flatInput = builder.create <tensor::ReshapeOp>(loc, flatInputType, input,
177+ flatInputShape);
176178
177179 return std::make_pair (flatInput, inputShape);
178180}
@@ -190,7 +192,8 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
190192 auto inputType = cast<RankedTensorType>(input.getType ());
191193 auto elementType = inputType.getElementType ();
192194 auto unrankedType = UnrankedTensorType::get (elementType);
193- return builder.create <tensor::ReshapeOp>(loc, unrankedType, input, inputShape);
195+ return builder.create <tensor::ReshapeOp>(loc, unrankedType, input,
196+ inputShape);
194197}
195198
196199// Create a tensor constant containing all scales in a per-channel quantized
@@ -209,7 +212,8 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc,
209212 auto scaleAttrs = llvm::map_to_vector (scales, [&](double scale) -> Attribute {
210213 return builder.getFloatAttr (expressedType, scale);
211214 });
212- auto tensorType = RankedTensorType::get ({(int64_t ) scales.size ()}, expressedType);
215+ auto tensorType =
216+ RankedTensorType::get ({(int64_t )scales.size ()}, expressedType);
213217 auto scalesAttr = DenseElementsAttr::get (tensorType, scaleAttrs);
214218 return builder.create <arith::ConstantOp>(loc, tensorType, scalesAttr);
215219}
@@ -228,9 +232,8 @@ Value materializePerChannelZeroPoints(
228232 UniformQuantizedPerAxisType quantizedType) {
229233 auto zeroPoints = quantizedType.getZeroPoints ();
230234 auto storageType = quantizedType.getStorageType ();
231- auto zeroPointAttrs = llvm::map_to_vector (
232- zeroPoints,
233- [&](int64_t zeroPoint) -> Attribute {
235+ auto zeroPointAttrs =
236+ llvm::map_to_vector (zeroPoints, [&](int64_t zeroPoint) -> Attribute {
234237 return builder.getIntegerAttr (storageType, zeroPoint);
235238 });
236239 auto tensorType =
@@ -239,6 +242,54 @@ Value materializePerChannelZeroPoints(
239242 return builder.create <arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
240243}
241244
245+ // Create a tensor constant containing all scales in a sub-channel quantized
246+ // type. Example:
247+ //
248+ // !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
249+ //
250+ // produces
251+ //
252+ // %cst = arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32>
253+ // The constant reuses the shape of the scales attribute; its element type is the expressed type.
254+ Value materializeSubChannelScales (
255+ OpBuilder &builder, Location loc,
256+ UniformQuantizedSubChannelType quantizedType) {
257+ auto scales = quantizedType.getScales (); // scales attribute; iterated as APFloat values below
258+ auto expressedType = quantizedType.getExpressedType ();
259+ auto scaleAttrs = llvm::map_to_vector ( // re-wrap every scale as a float attribute of the expressed type
260+ scales.getValues <APFloat>(), [&](APFloat scale) -> Attribute {
261+ return builder.getFloatAttr (expressedType, scale);
262+ });
263+ auto tensorType = // result type: shape of the scales attribute, expressed element type
264+ RankedTensorType::get (scales.getType ().getShape (), expressedType);
265+ auto scalesAttr = DenseElementsAttr::get (tensorType, scaleAttrs);
266+ return builder.create <arith::ConstantOp>(loc, tensorType, scalesAttr); // emit the constant at 'loc'
267+ }
268+
269+ // Create a tensor constant containing all zero points in a sub-channel
270+ // quantized type. Example:
271+ //
272+ // !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
273+ //
274+ // produces
275+ //
276+ // %cst = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi8>
277+ // The constant reuses the shape of the zero-points attribute; its element type is the storage type.
278+ Value materializeSubChannelZeroPoints (
279+ OpBuilder &builder, Location loc,
280+ UniformQuantizedSubChannelType quantizedType) {
281+ auto zeroPoints = quantizedType.getZeroPoints (); // zero-points attribute; iterated as APInt values below
282+ auto storageType = quantizedType.getStorageType ();
283+ auto zeroPointAttrs = llvm::map_to_vector ( // re-wrap every zero point as an integer attribute of the storage type
284+ zeroPoints.getValues <APInt>(), [&](APInt zeroPoint) -> Attribute {
285+ return builder.getIntegerAttr (storageType, zeroPoint);
286+ });
287+ auto tensorType = // result type: shape of the zero-points attribute, storage element type
288+ RankedTensorType::get (zeroPoints.getType ().getShape (), storageType);
289+ auto zeroPointsAttr = DenseElementsAttr::get (tensorType, zeroPointAttrs);
290+ return builder.create <arith::ConstantOp>(loc, tensorType, zeroPointsAttr); // emit the constant at 'loc'
291+ }
292+
242293// Clamp the given scalar or tensor input using the storage bounds encoded in
243294// the given quantized type, if present.
244295//
@@ -299,7 +350,7 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
299350 return builder.create <arith::UIToFPOp>(loc, resultType, input);
300351}
301352
302- // Quantize a scalar or ranked tensor value. The stored value is clamped using
353+ // Quantize a scalar or ranked tensor value. The stored value is clamped using
303354// the storage bounds encoded in the given quantized type.
304355//
305356// See function 'convertRanked()' below for a description of the arguments.
@@ -308,8 +359,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
308359 Value zeroPoint, QuantizedType quantizedType) {
309360 // Convert scale to tensor if necessary
310361 auto inputType = input.getType ();
311- scale = getScalarOrTensorConstant (
312- builder, loc, scale, inputType, inputShape);
362+ scale = getScalarOrTensorConstant (builder, loc, scale, inputType, inputShape);
313363
314364 // Scale input
315365 auto scaledValue = builder.create <arith::DivFOp>(loc, input, scale);
@@ -322,8 +372,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
322372 inputShape);
323373
324374 // Convert zero point from storage to expressed type
325- zeroPoint = convertIntegerToFloat (builder, loc, zeroPoint,
326- scale.getType (),
375+ zeroPoint = convertIntegerToFloat (builder, loc, zeroPoint, scale.getType (),
327376 quantizedType.isSigned ());
328377
329378 // Add zero point to stored value
@@ -334,9 +383,9 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
334383 // Convert stored value to storage type
335384 auto storageScalarOrTensorType =
336385 getScalarOrTensorType (quantizedType.getStorageType (), inputType);
337- auto storedValueInt = convertFloatToInteger (
338- builder, loc, storedValueFloat, storageScalarOrTensorType,
339- quantizedType.isSigned ());
386+ auto storedValueInt = convertFloatToInteger (builder, loc, storedValueFloat,
387+ storageScalarOrTensorType,
388+ quantizedType.isSigned ());
340389
341390 // Clamp the stored value if the storage type is bounded
342391 auto storedValueClamped = clampScalarOrTensor (builder, loc, storedValueInt,
@@ -352,12 +401,11 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
352401 Value zeroPoint, QuantizedType quantizedType) {
353402 // Convert scale to tensor if necessary
354403 auto inputType = input.getType ();
355- scale = getScalarOrTensorConstant (
356- builder, loc, scale, inputType, inputShape);
404+ scale = getScalarOrTensorConstant (builder, loc, scale, inputType, inputShape);
357405
358406 // Convert stored value to float
359- auto result = convertIntegerToFloat (
360- builder, loc, input, scale. getType (), quantizedType.isSigned ());
407+ auto result = convertIntegerToFloat (builder, loc, input, scale. getType (),
408+ quantizedType.isSigned ());
361409
362410 // Skip unnecessary computations if no zero point is given
363411 if (!matchPattern (zeroPoint, m_Zero ())) {
@@ -366,8 +414,7 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
366414 inputShape);
367415
368416 // Convert zero point from storage to expressed type
369- zeroPoint = convertIntegerToFloat (builder, loc, zeroPoint,
370- scale.getType (),
417+ zeroPoint = convertIntegerToFloat (builder, loc, zeroPoint, scale.getType (),
371418 quantizedType.isSigned ());
372419
373420 // Subtract zero point from stored value
@@ -501,35 +548,33 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
501548 auto initShape = tensor::getMixedSizes (builder, loc, input);
502549 Value init = builder.create <tensor::EmptyOp>(loc, initShape, elementType);
503550
504- SmallVector<utils::IteratorType> iteratorTypes (
505- inputRank, utils::IteratorType::parallel);
551+ SmallVector<utils::IteratorType> iteratorTypes (inputRank,
552+ utils::IteratorType::parallel);
506553 auto channelAxisAffineMap = AffineMap::get (
507554 inputRank, 0 , builder.getAffineDimExpr (channelAxis), context);
508555 SmallVector<AffineMap> indexingMaps{
509- builder.getMultiDimIdentityMap (inputRank),
510- channelAxisAffineMap,
511- channelAxisAffineMap,
512- builder.getMultiDimIdentityMap (inputRank)
513- };
514- auto result = builder.create <linalg::GenericOp>(
515- loc,
516- init.getType (), // resultType
517- ValueRange{input, scales, zeroPoints}, // inputs
518- ValueRange{init}, // outputs
519- indexingMaps,
520- iteratorTypes,
521- [&](OpBuilder& builder, Location loc, ValueRange args) {
522- assert (args.size () == 4 );
523- auto input = args[0 ];
524- auto scale = args[1 ];
525- auto zeroPoint = args[2 ];
526-
527- auto result = convertRanked (builder, loc, op, input, {}, scale,
528- zeroPoint, quantizedType);
529-
530- builder.create <linalg::YieldOp>(loc, result);
531- })
532- .getResult (0 );
556+ builder.getMultiDimIdentityMap (inputRank), channelAxisAffineMap,
557+ channelAxisAffineMap, builder.getMultiDimIdentityMap (inputRank)};
558+ auto result = builder
559+ .create <linalg::GenericOp>(
560+ loc,
561+ init.getType (), // resultType
562+ ValueRange{input, scales, zeroPoints}, // inputs
563+ ValueRange{init}, // outputs
564+ indexingMaps, iteratorTypes,
565+ [&](OpBuilder &builder, Location loc, ValueRange args) {
566+ assert (args.size () == 4 );
567+ auto input = args[0 ];
568+ auto scale = args[1 ];
569+ auto zeroPoint = args[2 ];
570+
571+ auto result =
572+ convertRanked (builder, loc, op, input, {}, scale,
573+ zeroPoint, quantizedType);
574+
575+ builder.create <linalg::YieldOp>(loc, result);
576+ })
577+ .getResult (0 );
533578
534579 return result;
535580}
@@ -551,7 +596,7 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
551596 // Flatten unranked tensor into a 3D ranked tensor if necessary
552597 bool isUnranked = isa<UnrankedTensorType>(input.getType ());
553598 int64_t channelAxis = quantizedType.getQuantizedDimension ();
554- int64_t channelAxisSize = (int64_t ) quantizedType.getScales ().size ();
599+ int64_t channelAxisSize = (int64_t )quantizedType.getScales ().size ();
555600 Value inputShape;
556601 if (isUnranked) {
557602 std::tie (input, inputShape) = flattenUnrankedTensorAroundAxis (
@@ -660,11 +705,17 @@ Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
660705 return convertPerChannel (builder, loc, op, input,
661706 uniformQuantizedPerAxisType);
662707
708+ if (auto uniformQuantizedSubChannelType =
709+ dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
710+ return convertSubChannel (builder, loc, op, input,
711+ uniformQuantizedSubChannelType);
712+
663713 llvm_unreachable (" unexpected quantized type" );
664714}
665715
666716// Lowering pattern for 'quant.dcast'
667- struct DequantizeCastOpConversion : public OpConversionPattern <quant::DequantizeCastOp> {
717+ struct DequantizeCastOpConversion
718+ : public OpConversionPattern<quant::DequantizeCastOp> {
668719 using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
669720
670721 LogicalResult
@@ -689,7 +740,8 @@ struct DequantizeCastOpConversion : public OpConversionPattern<quant::Dequantize
689740};
690741
691742// Lowering pattern for 'quant.qcast'
692- struct QuantizeCastOpConversion : public OpConversionPattern <quant::QuantizeCastOp> {
743+ struct QuantizeCastOpConversion
744+ : public OpConversionPattern<quant::QuantizeCastOp> {
693745 using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
694746
695747 LogicalResult
@@ -717,12 +769,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
717769 ConversionTarget target (getContext ());
718770 target.addLegalOp <quant::StorageCastOp>();
719771 target.addIllegalDialect <quant::QuantDialect>();
720- target.addLegalDialect <
721- arith::ArithDialect,
722- linalg::LinalgDialect,
723- shape::ShapeDialect,
724- tensor::TensorDialect
725- >();
772+ target.addLegalDialect <arith::ArithDialect, linalg::LinalgDialect,
773+ shape::ShapeDialect, tensor::TensorDialect>();
726774
727775 if (failed (applyPartialConversion (getOperation (), target,
728776 std::move (patterns))))
@@ -733,10 +781,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
733781} // namespace
734782
735783void populateLowerQuantOpsPatterns (RewritePatternSet &patterns) { // Adds the 'quant.dcast' and 'quant.qcast' lowering patterns to 'patterns'.
736- patterns.add <
737- DequantizeCastOpConversion,
738- QuantizeCastOpConversion
739- >(patterns.getContext ());
784+ patterns.add <DequantizeCastOpConversion, QuantizeCastOpConversion>(
785+ patterns.getContext ());
740786}
741787
742788} // namespace quant
0 commit comments