Skip to content

Commit c7f8ac0

Browse files
[ONNX] Add per channel quantization support for Onnx.QLinearConv op (#3917)
This commit extends the OnnxToTorch lowering for the Onnx.QLinearConv op by adding support for per-channel quantization of the weight argument. Since the convolution operation in the downstream ("Linalg") pipeline does not support per-channel quantization, we instead perform the convolution over the dequantized input and weight and then quantize the output. Fixes nod-ai/SHARK-ModelDev#894. Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 23610fc commit c7f8ac0

File tree

2 files changed

+231
-66
lines changed

2 files changed

+231
-66
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 179 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -326,30 +326,22 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
326326
patterns.onOp(
327327
"QLinearConv", 1,
328328
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
329+
Location loc = binder.getLoc();
329330
Torch::ValueTensorType resultType;
330331
llvm::SmallVector<Value> operands;
331332
if ((binder.tensorOperands(operands, 8) &&
332333
binder.tensorOperands(operands, 9)) ||
333334
binder.tensorResultType(resultType))
334335
return failure();
335-
Value a = operands[0];
336-
Value aScale = operands[1];
337-
Value aZp = operands[2];
338-
Value b = operands[3];
339-
Value bScale = operands[4];
340-
Value bZp = operands[5];
341-
Value cScale = operands[6];
342-
Value cZp = operands[7];
343-
Value c = operands.size() == 9 ? operands[8] : nullptr;
344-
345-
auto check = [](Value v) {
346-
auto vTy = cast<Torch::ValueTensorType>(v.getType());
347-
return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; });
348-
};
349-
if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
350-
!check(cScale) || !check(cScale))
351-
return rewriter.notifyMatchFailure(
352-
binder.op, "not supported for non per-tensor quantization");
336+
Value input = operands[0];
337+
Value inputScale = operands[1];
338+
Value inputZp = operands[2];
339+
Value weight = operands[3];
340+
Value weightScale = operands[4];
341+
Value weightZp = operands[5];
342+
Value outputScale = operands[6];
343+
Value outputZp = operands[7];
344+
Value bias = operands.size() == 9 ? operands[8] : nullptr;
353345

354346
auto extract = [&rewriter, &binder](Value v) {
355347
auto vTy = cast<Torch::ValueTensorType>(v.getType());
@@ -361,36 +353,153 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
361353
v);
362354
};
363355

364-
aZp = extract(aZp);
365-
bZp = extract(bZp);
366-
cZp = extract(cZp);
367-
aScale = extract(aScale);
368-
bScale = extract(bScale);
369-
cScale = extract(cScale);
356+
inputZp = extract(inputZp);
357+
outputZp = extract(outputZp);
358+
inputScale = extract(inputScale);
359+
outputScale = extract(outputScale);
370360

371-
auto make = [&rewriter, &binder](Value v, Value scale,
372-
Value zp) -> Value {
361+
auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
362+
Value zp) -> Value {
373363
auto ty = cast<Torch::ValueTensorType>(v.getType());
374364
auto newTy = getQTorchTypeFromTorchIntType(ty);
375365
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
376366
binder.getLoc(), newTy, v, scale, zp);
377367
};
378368

379-
a = make(a, aScale, aZp);
380-
b = make(b, bScale, bZp);
369+
// The onnx's QLinearConv op allows per channel quantization only for
370+
// the weight tensor for axis = 0.
371+
bool isPerChannelQuantization = false;
372+
auto weightTy = dyn_cast<Torch::ValueTensorType>(weight.getType());
373+
auto weightScaleTy =
374+
dyn_cast<Torch::ValueTensorType>(weightScale.getType());
375+
auto weightZpTy = dyn_cast<Torch::ValueTensorType>(weightZp.getType());
376+
if (!weightTy || !weightScaleTy || !weightZpTy ||
377+
!weightTy.hasSizes() || !weightScaleTy.hasSizes() ||
378+
!weightZpTy.hasSizes())
379+
return rewriter.notifyMatchFailure(
380+
binder.op, "Expected weight, weight_scale, and weight_zero_point "
381+
"arguments to have sizes");
382+
ArrayRef<int64_t> weightShape(weightTy.getSizes());
383+
SmallVector<int64_t> weightScaleShape(weightScaleTy.getSizes());
384+
SmallVector<int64_t> weightZpShape(weightZpTy.getSizes());
385+
if (weightScaleShape.size() == 0 ||
386+
llvm::all_of(weightScaleShape, [](int64_t s) { return s == 1; })) {
387+
weightZp = extract(weightZp);
388+
weightScale = extract(weightScale);
389+
weight = makePerTensor(weight, weightScale, weightZp);
390+
} else if (weightScaleShape.size() == 1 &&
391+
weightScaleShape[0] != Torch::kUnknownSize &&
392+
weightScaleShape[0] == weightShape[0]) {
393+
// Since the convolution operation in the downstream pipeline
394+
// ("Linalg") does not support the per-channel quantization, hence for
395+
// this particular case we perform the convolution over the
396+
// dequantized input and weight instead of relying on the downstream
397+
// pipeline to handle this. This code can be removed and made similar
398+
// to the other paths in this lowering once the per-channel
399+
// quantization support is added in the downstream pipeline.
400+
isPerChannelQuantization = true;
401+
402+
auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
403+
if (!inputTy || !inputTy.hasSizes())
404+
return rewriter.notifyMatchFailure(
405+
binder.op, "Expected input argument to have sizes");
381406

382-
auto cTy = rewriter.getType<Torch::ValueTensorType>(
383-
resultType.getOptionalSizes(),
384-
rewriter.getIntegerType(32, /*issigned=*/true));
407+
// Dequantizing the input
408+
// input = input.to(dtype=torch.float32)
409+
// input_dequant = (input - input_zero_point) * input_scale
385410

386-
// TODO(suderman): insert convolution operator.
387-
llvm::SmallVector<Value> newOperands = {a, b};
388-
if (c)
389-
newOperands.push_back(c);
411+
// Converting the input tensor to float32 type.
412+
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
413+
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
414+
Value float32Type = rewriter.create<Torch::ConstantIntOp>(
415+
loc, rewriter.getI64IntegerAttr(/*float32Type*/ 6));
416+
Type f32InputType = rewriter.getType<Torch::ValueTensorType>(
417+
inputTy.getSizes(), rewriter.getF32Type());
418+
input = rewriter.create<Torch::AtenToDtypeOp>(
419+
loc, f32InputType, input, float32Type,
420+
/*non_blocking=*/cstFalse,
421+
/*copy=*/cstFalse,
422+
/*memory_format=*/none);
390423

391-
cTy = rewriter.getType<Torch::ValueTensorType>(
392-
resultType.getOptionalSizes(),
393-
rewriter.getType<Torch::QInt32Type>());
424+
Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
425+
loc, rewriter.getF64FloatAttr(1.0));
426+
input = rewriter.create<Torch::AtenSubScalarOp>(
427+
loc, f32InputType, input, inputZp, cstOne);
428+
input = rewriter.create<Torch::AtenMulScalarOp>(loc, f32InputType,
429+
input, inputScale);
430+
431+
// Dequantizing the weight
432+
// Shapes of the inputs are as follows:
433+
// weight = (M x C/group x k1 x k2 x … x kn)
434+
// weight_scale = (M)
435+
// weight_zero_point = (M)
436+
//
437+
// We unsqueeze the weight_scale and weight_zero_point to match the
438+
// rank of weight. After unsqueeze:
439+
// weight_scale = (M, 1, 1, ..., 1)
440+
// weight_zero_point = (M, 1, 1, ..., 1)
441+
//
442+
// Then, we compute the dequantized weight:
443+
// weight = weight.to(dtype=torch.float32)
444+
// weight_dequant = (weight - weight_zero_point) * weight_scale
445+
int64_t diffRank = weightShape.size() - weightScaleShape.size();
446+
for (int i = 1; i <= diffRank; i++) {
447+
Value cstDim = rewriter.create<Torch::ConstantIntOp>(
448+
loc, rewriter.getI64IntegerAttr(i));
449+
450+
weightScaleShape.push_back(1);
451+
Type weightScaleUnsqueezeType = weightScaleTy.getWithSizesAndDtype(
452+
weightScaleShape, weightScaleTy.getOptionalDtype());
453+
weightScale = rewriter.create<Torch::AtenUnsqueezeOp>(
454+
loc, weightScaleUnsqueezeType, weightScale, cstDim);
455+
456+
weightZpShape.push_back(1);
457+
Type weightZpUnsqueezeType = weightZpTy.getWithSizesAndDtype(
458+
weightZpShape, weightZpTy.getOptionalDtype());
459+
weightZp = rewriter.create<Torch::AtenUnsqueezeOp>(
460+
loc, weightZpUnsqueezeType, weightZp, cstDim);
461+
}
462+
463+
// Converting the weight tensor to float32 type.
464+
Type f32WeightType = rewriter.getType<Torch::ValueTensorType>(
465+
weightShape, rewriter.getF32Type());
466+
weight = rewriter.create<Torch::AtenToDtypeOp>(
467+
loc, f32WeightType, weight, float32Type,
468+
/*non_blocking=*/cstFalse,
469+
/*copy=*/cstFalse,
470+
/*memory_format=*/none);
471+
472+
weight = rewriter.create<Torch::AtenSubTensorOp>(
473+
loc, f32WeightType, weight, weightZp, cstOne);
474+
weight = rewriter.create<Torch::AtenMulTensorOp>(loc, f32WeightType,
475+
weight, weightScale);
476+
477+
// Converting the bias tensor to float32 type.
478+
if (bias) {
479+
auto biasTy = dyn_cast<Torch::ValueTensorType>(bias.getType());
480+
if (!biasTy || !biasTy.hasSizes())
481+
return rewriter.notifyMatchFailure(
482+
binder.op, "Expected bias argument to have sizes");
483+
Type f32BiasType = rewriter.getType<Torch::ValueTensorType>(
484+
biasTy.getSizes(), rewriter.getF32Type());
485+
bias = rewriter.create<Torch::AtenToDtypeOp>(
486+
loc, f32BiasType, bias, float32Type,
487+
/*non_blocking=*/cstFalse,
488+
/*copy=*/cstFalse,
489+
/*memory_format=*/none);
490+
}
491+
492+
} else {
493+
llvm_unreachable("Unidentified case for weight quantization for "
494+
"Onnx.QLinearConv op");
495+
}
496+
497+
if (!isPerChannelQuantization)
498+
input = makePerTensor(input, inputScale, inputZp);
499+
500+
llvm::SmallVector<Value> newOperands = {input, weight};
501+
if (bias)
502+
newOperands.push_back(bias);
394503

395504
llvm::SmallVector<NamedAttribute> newAttributes;
396505
newAttributes.push_back(
@@ -402,36 +511,46 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
402511
newAttributes.push_back(namedAttr);
403512
}
404513

405-
c = rewriter
406-
.create<Torch::OperatorOp>(binder.getLoc(), cTy, newOperands,
407-
newAttributes,
408-
binder.op->getRegions().size())
409-
.getResult(0);
514+
Type convDtype =
515+
isPerChannelQuantization
516+
? cast<Type>(rewriter.getF32Type())
517+
: cast<Type>(rewriter.getType<Torch::QInt32Type>());
518+
auto outputTy = rewriter.getType<Torch::ValueTensorType>(
519+
resultType.getOptionalSizes(), convDtype);
520+
Value output = rewriter
521+
.create<Torch::OperatorOp>(
522+
binder.getLoc(), outputTy, newOperands,
523+
newAttributes, binder.op->getRegions().size())
524+
.getResult(0);
525+
526+
if (!isPerChannelQuantization) {
527+
Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
528+
binder.getLoc(), rewriter.getType<Torch::FloatType>(), inputScale,
529+
weightScale);
530+
Value outZp = rewriter.create<Torch::ConstantIntOp>(
531+
binder.getLoc(), rewriter.getType<Torch::IntType>(),
532+
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
533+
output = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
534+
binder.getLoc(), outputTy, output, outScale, outZp);
535+
outputTy = rewriter.getType<Torch::ValueTensorType>(
536+
resultType.getOptionalSizes(), rewriter.getF32Type());
410537

411-
Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
412-
binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
413-
bScale);
414-
Value outZp = rewriter.create<Torch::ConstantIntOp>(
415-
binder.getLoc(), rewriter.getType<Torch::IntType>(),
416-
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
417-
c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
418-
binder.getLoc(), cTy, c, outScale, outZp);
419-
cTy = rewriter.getType<Torch::ValueTensorType>(
420-
resultType.getOptionalSizes(), rewriter.getF32Type());
538+
output = rewriter.create<Torch::AtenDequantizeSelfOp>(
539+
binder.getLoc(), outputTy, output);
540+
}
421541

422-
c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
423-
c);
424-
cTy = getQTorchTypeFromTorchIntType(resultType);
542+
outputTy = getQTorchTypeFromTorchIntType(resultType);
425543
Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
426544
binder.getLoc(), rewriter.getType<Torch::IntType>(),
427545
rewriter.getIntegerAttr(
428546
rewriter.getIntegerType(64),
429547
static_cast<int64_t>(
430-
Torch::getScalarTypeForType(cTy.getDtype()))));
431-
c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
432-
binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
548+
Torch::getScalarTypeForType(outputTy.getDtype()))));
549+
550+
output = rewriter.create<Torch::AtenQuantizePerTensorOp>(
551+
binder.getLoc(), outputTy, output, outputScale, outputZp, dtyVal);
433552
rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
434-
c);
553+
output);
435554
return success();
436555
});
437556
patterns.onOp(

0 commit comments

Comments
 (0)