@@ -326,30 +326,22 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
326
326
patterns.onOp (
327
327
" QLinearConv" , 1 ,
328
328
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
329
+ Location loc = binder.getLoc ();
329
330
Torch::ValueTensorType resultType;
330
331
llvm::SmallVector<Value> operands;
331
332
if ((binder.tensorOperands (operands, 8 ) &&
332
333
binder.tensorOperands (operands, 9 )) ||
333
334
binder.tensorResultType (resultType))
334
335
return failure ();
335
- Value a = operands[0 ];
336
- Value aScale = operands[1 ];
337
- Value aZp = operands[2 ];
338
- Value b = operands[3 ];
339
- Value bScale = operands[4 ];
340
- Value bZp = operands[5 ];
341
- Value cScale = operands[6 ];
342
- Value cZp = operands[7 ];
343
- Value c = operands.size () == 9 ? operands[8 ] : nullptr ;
344
-
345
- auto check = [](Value v) {
346
- auto vTy = cast<Torch::ValueTensorType>(v.getType ());
347
- return llvm::all_of (vTy.getSizes (), [](int64_t d) { return d == 1 ; });
348
- };
349
- if (!check (aScale) || !check (aZp) || !check (bScale) || !check (bZp) ||
350
- !check (cScale) || !check (cScale))
351
- return rewriter.notifyMatchFailure (
352
- binder.op , " not supported for non per-tensor quantization" );
336
+ Value input = operands[0 ];
337
+ Value inputScale = operands[1 ];
338
+ Value inputZp = operands[2 ];
339
+ Value weight = operands[3 ];
340
+ Value weightScale = operands[4 ];
341
+ Value weightZp = operands[5 ];
342
+ Value outputScale = operands[6 ];
343
+ Value outputZp = operands[7 ];
344
+ Value bias = operands.size () == 9 ? operands[8 ] : nullptr ;
353
345
354
346
auto extract = [&rewriter, &binder](Value v) {
355
347
auto vTy = cast<Torch::ValueTensorType>(v.getType ());
@@ -361,36 +353,153 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
361
353
v);
362
354
};
363
355
364
- aZp = extract (aZp);
365
- bZp = extract (bZp);
366
- cZp = extract (cZp);
367
- aScale = extract (aScale);
368
- bScale = extract (bScale);
369
- cScale = extract (cScale);
356
+ inputZp = extract (inputZp);
357
+ outputZp = extract (outputZp);
358
+ inputScale = extract (inputScale);
359
+ outputScale = extract (outputScale);
370
360
371
- auto make = [&rewriter, &binder](Value v, Value scale,
372
- Value zp) -> Value {
361
+ auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
362
+ Value zp) -> Value {
373
363
auto ty = cast<Torch::ValueTensorType>(v.getType ());
374
364
auto newTy = getQTorchTypeFromTorchIntType (ty);
375
365
return rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
376
366
binder.getLoc (), newTy, v, scale, zp);
377
367
};
378
368
379
- a = make (a, aScale, aZp);
380
- b = make (b, bScale, bZp);
369
+ // The onnx's QLinearConv op allows per channel quantization only for
370
+ // the weight tensor for axis = 0.
371
+ bool isPerChannelQuantization = false ;
372
+ auto weightTy = dyn_cast<Torch::ValueTensorType>(weight.getType ());
373
+ auto weightScaleTy =
374
+ dyn_cast<Torch::ValueTensorType>(weightScale.getType ());
375
+ auto weightZpTy = dyn_cast<Torch::ValueTensorType>(weightZp.getType ());
376
+ if (!weightTy || !weightScaleTy || !weightZpTy ||
377
+ !weightTy.hasSizes () || !weightScaleTy.hasSizes () ||
378
+ !weightZpTy.hasSizes ())
379
+ return rewriter.notifyMatchFailure (
380
+ binder.op , " Expected weight, weight_scale, and weight_zero_point "
381
+ " arguments to have sizes" );
382
+ ArrayRef<int64_t > weightShape (weightTy.getSizes ());
383
+ SmallVector<int64_t > weightScaleShape (weightScaleTy.getSizes ());
384
+ SmallVector<int64_t > weightZpShape (weightZpTy.getSizes ());
385
+ if (weightScaleShape.size () == 0 ||
386
+ llvm::all_of (weightScaleShape, [](int64_t s) { return s == 1 ; })) {
387
+ weightZp = extract (weightZp);
388
+ weightScale = extract (weightScale);
389
+ weight = makePerTensor (weight, weightScale, weightZp);
390
+ } else if (weightScaleShape.size () == 1 &&
391
+ weightScaleShape[0 ] != Torch::kUnknownSize &&
392
+ weightScaleShape[0 ] == weightShape[0 ]) {
393
+ // Since the convolution operation in the downstream pipeline
394
+ // ("Linalg") does not support the per-channel quantization, hence for
395
+ // this particular case we perform the convolution over the
396
+ // dequantized input and weight instead of relying on the downstream
397
+ // pipeline to handle this. This code can be removed and made similar
398
+ // to the other paths in this lowering once the per-channel
399
+ // quantization support is added in the downstream pipeline.
400
+ isPerChannelQuantization = true ;
401
+
402
+ auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType ());
403
+ if (!inputTy || !inputTy.hasSizes ())
404
+ return rewriter.notifyMatchFailure (
405
+ binder.op , " Expected input argument to have sizes" );
381
406
382
- auto cTy = rewriter. getType <Torch::ValueTensorType>(
383
- resultType. getOptionalSizes (),
384
- rewriter. getIntegerType ( 32 , /* issigned= */ true ));
407
+ // Dequantizing the input
408
+ // input = input.to(dtype=torch.float32)
409
+ // input_dequant = (input - input_zero_point) * input_scale
385
410
386
- // TODO(suderman): insert convolution operator.
387
- llvm::SmallVector<Value> newOperands = {a, b};
388
- if (c)
389
- newOperands.push_back (c);
411
+ // Converting the input tensor to float32 type.
412
+ Value none = rewriter.create <Torch::ConstantNoneOp>(loc);
413
+ Value cstFalse = rewriter.create <Torch::ConstantBoolOp>(loc, false );
414
+ Value float32Type = rewriter.create <Torch::ConstantIntOp>(
415
+ loc, rewriter.getI64IntegerAttr (/* float32Type*/ 6 ));
416
+ Type f32InputType = rewriter.getType <Torch::ValueTensorType>(
417
+ inputTy.getSizes (), rewriter.getF32Type ());
418
+ input = rewriter.create <Torch::AtenToDtypeOp>(
419
+ loc, f32InputType, input, float32Type,
420
+ /* non_blocking=*/ cstFalse,
421
+ /* copy=*/ cstFalse,
422
+ /* memory_format=*/ none);
390
423
391
- cTy = rewriter.getType <Torch::ValueTensorType>(
392
- resultType.getOptionalSizes (),
393
- rewriter.getType <Torch::QInt32Type>());
424
+ Value cstOne = rewriter.create <Torch::ConstantFloatOp>(
425
+ loc, rewriter.getF64FloatAttr (1.0 ));
426
+ input = rewriter.create <Torch::AtenSubScalarOp>(
427
+ loc, f32InputType, input, inputZp, cstOne);
428
+ input = rewriter.create <Torch::AtenMulScalarOp>(loc, f32InputType,
429
+ input, inputScale);
430
+
431
+ // Dequantizing the weight
432
+ // Shapes of the inputs are as follows:
433
+ // weight = (M x C/group x k1 x k2 x … x kn)
434
+ // weight_scale = (M)
435
+ // weight_zero_point = (M)
436
+ //
437
+ // We unsqueeze the weight_scale and weight_zero_point to match the
438
+ // rank of weight. After unsqueeze:
439
+ // weight_scale = (M, 1, 1, ..., 1)
440
+ // weight_zero_point = (M, 1, 1, ..., 1)
441
+ //
442
+ // Then, we compute the dequantized weight:
443
+ // weight = weight.to(dtype=torch.float32)
444
+ // weight_dequant = (weight - weight_zero_point) * weight_scale
445
+ int64_t diffRank = weightShape.size () - weightScaleShape.size ();
446
+ for (int i = 1 ; i <= diffRank; i++) {
447
+ Value cstDim = rewriter.create <Torch::ConstantIntOp>(
448
+ loc, rewriter.getI64IntegerAttr (i));
449
+
450
+ weightScaleShape.push_back (1 );
451
+ Type weightScaleUnsqueezeType = weightScaleTy.getWithSizesAndDtype (
452
+ weightScaleShape, weightScaleTy.getOptionalDtype ());
453
+ weightScale = rewriter.create <Torch::AtenUnsqueezeOp>(
454
+ loc, weightScaleUnsqueezeType, weightScale, cstDim);
455
+
456
+ weightZpShape.push_back (1 );
457
+ Type weightZpUnsqueezeType = weightZpTy.getWithSizesAndDtype (
458
+ weightZpShape, weightZpTy.getOptionalDtype ());
459
+ weightZp = rewriter.create <Torch::AtenUnsqueezeOp>(
460
+ loc, weightZpUnsqueezeType, weightZp, cstDim);
461
+ }
462
+
463
+ // Converting the weight tensor to float32 type.
464
+ Type f32WeightType = rewriter.getType <Torch::ValueTensorType>(
465
+ weightShape, rewriter.getF32Type ());
466
+ weight = rewriter.create <Torch::AtenToDtypeOp>(
467
+ loc, f32WeightType, weight, float32Type,
468
+ /* non_blocking=*/ cstFalse,
469
+ /* copy=*/ cstFalse,
470
+ /* memory_format=*/ none);
471
+
472
+ weight = rewriter.create <Torch::AtenSubTensorOp>(
473
+ loc, f32WeightType, weight, weightZp, cstOne);
474
+ weight = rewriter.create <Torch::AtenMulTensorOp>(loc, f32WeightType,
475
+ weight, weightScale);
476
+
477
+ // Converting the bias tensor to float32 type.
478
+ if (bias) {
479
+ auto biasTy = dyn_cast<Torch::ValueTensorType>(bias.getType ());
480
+ if (!biasTy || !biasTy.hasSizes ())
481
+ return rewriter.notifyMatchFailure (
482
+ binder.op , " Expected bias argument to have sizes" );
483
+ Type f32BiasType = rewriter.getType <Torch::ValueTensorType>(
484
+ biasTy.getSizes (), rewriter.getF32Type ());
485
+ bias = rewriter.create <Torch::AtenToDtypeOp>(
486
+ loc, f32BiasType, bias, float32Type,
487
+ /* non_blocking=*/ cstFalse,
488
+ /* copy=*/ cstFalse,
489
+ /* memory_format=*/ none);
490
+ }
491
+
492
+ } else {
493
+ llvm_unreachable (" Unidentified case for weight quantization for "
494
+ " Onnx.QLinearConv op" );
495
+ }
496
+
497
+ if (!isPerChannelQuantization)
498
+ input = makePerTensor (input, inputScale, inputZp);
499
+
500
+ llvm::SmallVector<Value> newOperands = {input, weight};
501
+ if (bias)
502
+ newOperands.push_back (bias);
394
503
395
504
llvm::SmallVector<NamedAttribute> newAttributes;
396
505
newAttributes.push_back (
@@ -402,36 +511,46 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
402
511
newAttributes.push_back (namedAttr);
403
512
}
404
513
405
- c = rewriter
406
- .create <Torch::OperatorOp>(binder.getLoc (), cTy, newOperands,
407
- newAttributes,
408
- binder.op ->getRegions ().size ())
409
- .getResult (0 );
514
+ Type convDtype =
515
+ isPerChannelQuantization
516
+ ? cast<Type>(rewriter.getF32Type ())
517
+ : cast<Type>(rewriter.getType <Torch::QInt32Type>());
518
+ auto outputTy = rewriter.getType <Torch::ValueTensorType>(
519
+ resultType.getOptionalSizes (), convDtype);
520
+ Value output = rewriter
521
+ .create <Torch::OperatorOp>(
522
+ binder.getLoc (), outputTy, newOperands,
523
+ newAttributes, binder.op ->getRegions ().size ())
524
+ .getResult (0 );
525
+
526
+ if (!isPerChannelQuantization) {
527
+ Value outScale = rewriter.create <Torch::AtenMulFloatOp>(
528
+ binder.getLoc (), rewriter.getType <Torch::FloatType>(), inputScale,
529
+ weightScale);
530
+ Value outZp = rewriter.create <Torch::ConstantIntOp>(
531
+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
532
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), 0 ));
533
+ output = rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
534
+ binder.getLoc (), outputTy, output, outScale, outZp);
535
+ outputTy = rewriter.getType <Torch::ValueTensorType>(
536
+ resultType.getOptionalSizes (), rewriter.getF32Type ());
410
537
411
- Value outScale = rewriter.create <Torch::AtenMulFloatOp>(
412
- binder.getLoc (), rewriter.getType <Torch::FloatType>(), aScale,
413
- bScale);
414
- Value outZp = rewriter.create <Torch::ConstantIntOp>(
415
- binder.getLoc (), rewriter.getType <Torch::IntType>(),
416
- rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), 0 ));
417
- c = rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
418
- binder.getLoc (), cTy, c, outScale, outZp);
419
- cTy = rewriter.getType <Torch::ValueTensorType>(
420
- resultType.getOptionalSizes (), rewriter.getF32Type ());
538
+ output = rewriter.create <Torch::AtenDequantizeSelfOp>(
539
+ binder.getLoc (), outputTy, output);
540
+ }
421
541
422
- c = rewriter.create <Torch::AtenDequantizeSelfOp>(binder.getLoc (), cTy,
423
- c);
424
- cTy = getQTorchTypeFromTorchIntType (resultType);
542
+ outputTy = getQTorchTypeFromTorchIntType (resultType);
425
543
Value dtyVal = rewriter.create <Torch::ConstantIntOp>(
426
544
binder.getLoc (), rewriter.getType <Torch::IntType>(),
427
545
rewriter.getIntegerAttr (
428
546
rewriter.getIntegerType (64 ),
429
547
static_cast <int64_t >(
430
- Torch::getScalarTypeForType (cTy.getDtype ()))));
431
- c = rewriter.create <Torch::AtenQuantizePerTensorOp>(
432
- binder.getLoc (), cTy, c, cScale, cZp, dtyVal);
548
+ Torch::getScalarTypeForType (outputTy.getDtype ()))));
549
+
550
+ output = rewriter.create <Torch::AtenQuantizePerTensorOp>(
551
+ binder.getLoc (), outputTy, output, outputScale, outputZp, dtyVal);
433
552
rewriter.replaceOpWithNewOp <Torch::AtenIntReprOp>(binder.op , resultType,
434
- c );
553
+ output );
435
554
return success ();
436
555
});
437
556
patterns.onOp (
0 commit comments