Skip to content

Commit 6240480

Browse files
[ONNX] Modify QLinear* ops lowering to use common utilities (#4178)
Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent bdad744 commit 6240480

File tree

2 files changed

+78
-167
lines changed

2 files changed

+78
-167
lines changed

lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp

Lines changed: 74 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -469,95 +469,49 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
469469
return rewriter.notifyMatchFailure(
470470
binder.op, "Unimplemented: expected 8 input operands");
471471

472-
Value a = operands[0];
473-
Value aScale = operands[1];
474-
Value aZp = operands[2];
475-
Value b = operands[3];
476-
Value bScale = operands[4];
477-
Value bZp = operands[5];
478-
Value cScale = operands[6];
479-
Value cZp = operands[7];
480-
481-
auto check = [](Value v) {
482-
auto vTy = cast<Torch::ValueTensorType>(v.getType());
483-
for (auto dim : vTy.getSizes())
484-
if (dim != 1)
485-
return false;
486-
return true;
487-
};
488-
if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
489-
!check(cScale) || !check(cZp))
490-
return rewriter.notifyMatchFailure(
491-
binder.op, "Unsupported per-tensor quantization");
492-
493-
Value emptyList = rewriter.create<Torch::PrimListConstructOp>(
494-
binder.getLoc(),
495-
rewriter.getType<Torch::ListType>(
496-
rewriter.getType<Torch::IntType>()),
497-
ValueRange{});
498-
auto extract = [&rewriter, &binder, &emptyList](Value v) {
499-
auto vTy = cast<Torch::ValueTensorType>(v.getType());
500-
if (!vTy.getSizes().empty()) {
501-
vTy = rewriter.getType<Torch::ValueTensorType>(
502-
ArrayRef<int64_t>({}), vTy.getOptionalDtype());
503-
v = rewriter.create<Torch::AtenReshapeOp>(binder.getLoc(), vTy, v,
504-
emptyList);
505-
}
506-
507-
Type extractTy = rewriter.getType<Torch::FloatType>();
508-
if (isa<IntegerType>(vTy.getDtype()))
509-
extractTy = rewriter.getType<Torch::IntType>();
472+
Value a, aScale, aZp, b, bScale, bZp, cScale, cZp;
510473

511-
return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
512-
v);
513-
};
514-
515-
aZp = extract(aZp);
516-
bZp = extract(bZp);
517-
cZp = extract(cZp);
518-
519-
aScale = extract(aScale);
520-
bScale = extract(bScale);
521-
cScale = extract(cScale);
522-
523-
auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
524-
Value zp) -> Value {
525-
auto ty = cast<Torch::ValueTensorType>(v.getType());
526-
auto newTy = getQTorchTypeFromTorchIntType(ty);
527-
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
528-
binder.getLoc(), newTy, v, scale, zp);
529-
};
530-
531-
a = makePerTensor(a, aScale, aZp);
532-
b = makePerTensor(b, bScale, bZp);
474+
if (failed(extractPerTensorQuantizationArguments(
475+
rewriter, loc, /*scale=*/operands[1],
476+
/*zero_point=*/operands[2], aScale, aZp)))
477+
return rewriter.notifyMatchFailure(
478+
binder.op, "Incompatible arguments for per-tensor quantization");
533479

534-
auto aTy = dyn_cast<Torch::ValueTensorType>(a.getType());
535-
if (!aTy || !aTy.hasSizes())
480+
if (failed(extractPerTensorQuantizationArguments(
481+
rewriter, loc, /*scale=*/operands[4],
482+
/*zero_point=*/operands[5], bScale, bZp)))
536483
return rewriter.notifyMatchFailure(
537-
binder.op, "Expected input argument `a` to have sizes");
484+
binder.op, "Incompatible arguments for per-tensor quantization");
538485

539-
aTy = rewriter.getType<Torch::ValueTensorType>(aTy.getOptionalSizes(),
540-
rewriter.getF32Type());
541-
a = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), aTy,
542-
a);
486+
if (failed(extractPerTensorQuantizationArguments(
487+
rewriter, loc, /*scale=*/operands[6],
488+
/*zero_point=*/operands[7], cScale, cZp)))
489+
return rewriter.notifyMatchFailure(
490+
binder.op, "Incompatible arguments for per-tensor quantization");
543491

544-
auto bTy = dyn_cast<Torch::ValueTensorType>(b.getType());
545-
if (!bTy || !bTy.hasSizes())
492+
if (failed(createDequantizeTensor(rewriter, loc, /*input=*/operands[0],
493+
/*scale=*/aScale, /*zero_point=*/aZp,
494+
/*output=*/a)))
546495
return rewriter.notifyMatchFailure(
547-
binder.op, "Expected input argument `b` to have sizes");
496+
binder.op, "Failed to dequantize the input tensor `a` because of "
497+
"missing sizes");
548498

549-
bTy = rewriter.getType<Torch::ValueTensorType>(bTy.getOptionalSizes(),
550-
rewriter.getF32Type());
551-
b = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), bTy,
552-
b);
499+
if (failed(createDequantizeTensor(rewriter, loc, /*input=*/operands[3],
500+
/*scale=*/bScale, /*zero_point=*/bZp,
501+
/*output=*/b)))
502+
return rewriter.notifyMatchFailure(
503+
binder.op, "Failed to dequantize the input tensor `b` because of "
504+
"missing sizes");
553505

506+
// Computing the result of "Add".
554507
auto cTy = rewriter.getType<Torch::ValueTensorType>(
555508
resultType.getOptionalSizes(), rewriter.getF32Type());
556509
Value alpha = rewriter.create<Torch::ConstantFloatOp>(
557510
loc, rewriter.getF64FloatAttr(1.0));
558511
Value c = rewriter.create<Torch::AtenAddTensorOp>(binder.getLoc(), cTy,
559512
a, b, alpha);
560513

514+
// Quantizing the result of "Add" operation.
561515
cTy = dyn_cast<Torch::ValueTensorType>(
562516
getQTorchTypeFromTorchIntType(resultType));
563517
Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
@@ -588,11 +542,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
588542
return rewriter.notifyMatchFailure(
589543
binder.op, "Unimplemented: expected 5 input operands");
590544

591-
Value x = operands[0];
592-
Value xScale = operands[1];
593-
Value xZp = operands[2];
594-
Value yScale = operands[3];
595-
Value yZp = operands[4];
545+
Value x, xScale, xZp, yScale, yZp;
596546

597547
if (failed(extractPerTensorQuantizationArguments(
598548
rewriter, loc, /*scale=*/operands[1],
@@ -606,18 +556,12 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
606556
return rewriter.notifyMatchFailure(
607557
binder.op, "Incompatible arguments for per-tensor quantization");
608558

609-
auto xTy = dyn_cast<Torch::ValueTensorType>(x.getType());
610-
if (!xTy || !xTy.hasSizes())
559+
if (failed(createDequantizeTensor(rewriter, loc, /*input=*/operands[0],
560+
/*scale=*/xScale, /*zero_point=*/xZp,
561+
/*output=*/x)))
611562
return rewriter.notifyMatchFailure(
612-
binder.op, "Expected input argument `x` to have sizes");
613-
614-
xTy = getQTorchTypeFromTorchIntType(xTy);
615-
x = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
616-
loc, xTy, x, xScale, xZp);
617-
xTy = rewriter.getType<Torch::ValueTensorType>(xTy.getSizes(),
618-
rewriter.getF32Type());
619-
// Dequantizing the input tensor `x`.
620-
x = rewriter.create<Torch::AtenDequantizeSelfOp>(loc, xTy, x);
563+
binder.op, "Failed to dequantize the input tensor `x` because of "
564+
"missing sizes");
621565

622566
// Computing the LeakyRelu result.
623567
Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
@@ -670,16 +614,8 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
670614
binder.op, "Incompatible number of input operands, scales and/or "
671615
"zero-points");
672616

673-
auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
674-
Value zp) -> Value {
675-
auto ty = cast<Torch::ValueTensorType>(v.getType());
676-
auto newTy = getQTorchTypeFromTorchIntType(ty);
677-
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
678-
binder.getLoc(), newTy, v, scale, zp);
679-
};
680-
681-
// Preparing the quantized inputs.
682-
SmallVector<Value> quantizedInputs;
617+
// Preparing the dequantized inputs.
618+
SmallVector<Value> dequantizedInputs;
683619
for (unsigned i = 0; i < numInputs; i++) {
684620
Value scale, zeroPoint;
685621
if (failed(extractPerTensorQuantizationArguments(
@@ -689,24 +625,15 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
689625
binder.op, "Incompatible scale and zero-points argument for "
690626
"per-tensor quantization");
691627

692-
quantizedInputs.push_back(makePerTensor(inputs[i], scale, zeroPoint));
693-
}
694-
695-
// Dequantizing the inputs.
696-
SmallVector<Value> dequantizedInputs;
697-
for (unsigned i = 0; i < numInputs; i++) {
698-
Torch::ValueTensorType inputTy =
699-
dyn_cast<Torch::ValueTensorType>(quantizedInputs[i].getType());
700-
if (!inputTy || !inputTy.hasSizes())
628+
Value dequantizedInput;
629+
if (failed(createDequantizeTensor(rewriter, loc, inputs[i], scale,
630+
zeroPoint,
631+
/*output=*/dequantizedInput)))
701632
return rewriter.notifyMatchFailure(
702-
binder.op, "Expected tensor input operands to be concatenated "
703-
"to have sizes");
704-
705-
inputTy = rewriter.getType<Torch::ValueTensorType>(
706-
inputTy.getOptionalSizes(), rewriter.getF32Type());
707-
dequantizedInputs.push_back(
708-
rewriter.create<Torch::AtenDequantizeSelfOp>(loc, inputTy,
709-
quantizedInputs[i]));
633+
binder.op, "Failed to dequantize the input tensor because of "
634+
"missing sizes");
635+
636+
dequantizedInputs.push_back(dequantizedInput);
710637
}
711638

712639
// Concatenating the inputs.
@@ -764,8 +691,19 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
764691
binder.op,
765692
"Unimplemented: support not present for channels_last attribute");
766693

767-
Value x = operands[0];
768-
Value xScale, xZp, yScale, yZp;
694+
auto xTy = dyn_cast<Torch::ValueTensorType>(operands[0].getType());
695+
if (!xTy || !xTy.hasSizes())
696+
return rewriter.notifyMatchFailure(
697+
binder.op, "Expected input argument `x` to have sizes");
698+
ArrayRef<int64_t> inputShape = xTy.getSizes();
699+
700+
if (!resultType || !resultType.hasSizes()) {
701+
return rewriter.notifyMatchFailure(
702+
binder.op, "Expected result type having sizes");
703+
}
704+
ArrayRef<int64_t> resultShape = resultType.getSizes();
705+
706+
Value x, xScale, xZp, yScale, yZp;
769707

770708
if (failed(extractPerTensorQuantizationArguments(
771709
rewriter, loc, /*scale=*/operands[1],
@@ -779,25 +717,12 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
779717
return rewriter.notifyMatchFailure(
780718
binder.op, "Incompatible arguments for per-tensor quantization");
781719

782-
auto xTy = dyn_cast<Torch::ValueTensorType>(x.getType());
783-
if (!xTy || !xTy.hasSizes())
784-
return rewriter.notifyMatchFailure(
785-
binder.op, "Expected input argument `x` to have sizes");
786-
ArrayRef<int64_t> inputShape = xTy.getSizes();
787-
788-
xTy = getQTorchTypeFromTorchIntType(xTy);
789-
x = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
790-
loc, xTy, x, xScale, xZp);
791-
xTy = rewriter.getType<Torch::ValueTensorType>(inputShape,
792-
rewriter.getF32Type());
793-
// Dequantizing the input tensor `x`.
794-
x = rewriter.create<Torch::AtenDequantizeSelfOp>(loc, xTy, x);
795-
796-
if (!resultType || !resultType.hasSizes()) {
720+
if (failed(createDequantizeTensor(rewriter, loc, /*input=*/operands[0],
721+
/*scale=*/xScale, /*zero_point=*/xZp,
722+
/*output=*/x)))
797723
return rewriter.notifyMatchFailure(
798-
binder.op, "Expected result type having sizes");
799-
}
800-
ArrayRef<int64_t> resultShape = resultType.getSizes();
724+
binder.op, "Failed to dequantize the input tensor `x` because of "
725+
"missing sizes");
801726

802727
// Computing the AvgPool result.
803728
SmallVector<Value> cstKernel, cstPadding, cstStrides;
@@ -888,8 +813,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
888813
return rewriter.notifyMatchFailure(
889814
binder.op, "Unimplemented: expected 5 input operands");
890815

891-
Value x = operands[0];
892-
Value xScale, xZp, yScale, yZp;
816+
Value x, xScale, xZp, yScale, yZp;
893817

894818
if (failed(extractPerTensorQuantizationArguments(
895819
rewriter, loc, /*scale=*/operands[1],
@@ -903,18 +827,12 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
903827
return rewriter.notifyMatchFailure(
904828
binder.op, "Incompatible arguments for per-tensor quantization");
905829

906-
auto xTy = dyn_cast<Torch::ValueTensorType>(x.getType());
907-
if (!xTy || !xTy.hasSizes())
830+
if (failed(createDequantizeTensor(rewriter, loc, /*input=*/operands[0],
831+
/*scale=*/xScale, /*zero_point=*/xZp,
832+
/*output=*/x)))
908833
return rewriter.notifyMatchFailure(
909-
binder.op, "Expected input argument `x` to have sizes");
910-
911-
xTy = getQTorchTypeFromTorchIntType(xTy);
912-
x = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
913-
loc, xTy, x, xScale, xZp);
914-
xTy = rewriter.getType<Torch::ValueTensorType>(xTy.getSizes(),
915-
rewriter.getF32Type());
916-
// Dequantizing the input tensor `x`.
917-
x = rewriter.create<Torch::AtenDequantizeSelfOp>(loc, xTy, x);
834+
binder.op, "Failed to dequantize the input tensor `x` because of "
835+
"missing sizes");
918836

919837
// Computing the Sigmoid result.
920838
auto yTy = rewriter.getType<Torch::ValueTensorType>(
@@ -958,8 +876,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
958876
return rewriter.notifyMatchFailure(
959877
binder.op, "Unimplemented: expected 5 input operands");
960878

961-
Value x = operands[0];
962-
Value xScale, xZp, yScale, yZp;
879+
Value x, xScale, xZp, yScale, yZp;
963880

964881
if (failed(extractPerTensorQuantizationArguments(
965882
rewriter, loc, /*scale=*/operands[1],
@@ -973,18 +890,12 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
973890
return rewriter.notifyMatchFailure(
974891
binder.op, "Incompatible arguments for per-tensor quantization");
975892

976-
auto xTy = dyn_cast<Torch::ValueTensorType>(x.getType());
977-
if (!xTy || !xTy.hasSizes())
893+
if (failed(createDequantizeTensor(rewriter, loc, /*input=*/operands[0],
894+
/*scale=*/xScale, /*zero_point=*/xZp,
895+
/*output=*/x)))
978896
return rewriter.notifyMatchFailure(
979-
binder.op, "Expected input argument `x` to have sizes");
980-
981-
xTy = getQTorchTypeFromTorchIntType(xTy);
982-
x = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
983-
loc, xTy, x, xScale, xZp);
984-
xTy = rewriter.getType<Torch::ValueTensorType>(xTy.getSizes(),
985-
rewriter.getF32Type());
986-
// Dequantizing the input tensor `x`.
987-
x = rewriter.create<Torch::AtenDequantizeSelfOp>(loc, xTy, x);
897+
binder.op, "Failed to dequantize the input tensor `x` because of "
898+
"missing sizes");
988899

989900
// Creating Onnx.AveragePool op.
990901
llvm::SmallVector<Value> newOperands = {x};

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3706,8 +3706,8 @@ func.func @test_qlinearadd(%arg0: !torch.vtensor<[1,4096],ui8>, %arg1: !torch.vt
37063706
// CHECK-DAG: %[[CSCALE:.+]] = torch.aten.item %[[C_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
37073707
// CHECK-DAG: %[[A_QUANT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[A]], %[[ASCALE]], %[[AZP]] : !torch.vtensor<[1,4096],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,4096],!torch.quint8>
37083708
// CHECK-DAG: %[[B_QUANT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[B]], %[[BSCALE]], %[[BZP]] : !torch.vtensor<[4096],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4096],!torch.quint8>
3709-
// CHECK: %[[A_F32:.+]] = torch.aten.dequantize.self %[[A_QUANT]] : !torch.vtensor<[1,4096],!torch.quint8> -> !torch.vtensor<[1,4096],f32>
3710-
// CHECK: %[[B_F32:.+]] = torch.aten.dequantize.self %[[B_QUANT]] : !torch.vtensor<[4096],!torch.quint8> -> !torch.vtensor<[4096],f32>
3709+
// CHECK-DAG: %[[A_F32:.+]] = torch.aten.dequantize.self %[[A_QUANT]] : !torch.vtensor<[1,4096],!torch.quint8> -> !torch.vtensor<[1,4096],f32>
3710+
// CHECK-DAG: %[[B_F32:.+]] = torch.aten.dequantize.self %[[B_QUANT]] : !torch.vtensor<[4096],!torch.quint8> -> !torch.vtensor<[4096],f32>
37113711
// CHECK: %[[ALPHA:.+]] = torch.constant.float 1.000000e+00
37123712
// CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[A_F32]], %[[B_F32]], %[[ALPHA]] : !torch.vtensor<[1,4096],f32>, !torch.vtensor<[4096],f32>, !torch.float -> !torch.vtensor<[1,4096],f32>
37133713
// CHECK: %[[DTY:.+]] = torch.constant.int 13
@@ -3752,8 +3752,8 @@ func.func @test_qlinearconcat(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtens
37523752
// CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
37533753
// CHECK-DAG: %[[QUANT_INPUT_1:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg2, %{{.+}}, %{{.+}} : !torch.vtensor<[?,?,?,?],ui8>, !torch.float, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.quint8>
37543754
// CHECK-DAG: %[[QUANT_INPUT_2:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg5, %{{.+}}, %{{.+}} : !torch.vtensor<[?,?,?,?],ui8>, !torch.float, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.quint8>
3755-
// CHECK: %[[DEQUANT_INPUT_1:.+]] = torch.aten.dequantize.self %[[QUANT_INPUT_1]] : !torch.vtensor<[?,?,?,?],!torch.quint8> -> !torch.vtensor<[?,?,?,?],f32>
3756-
// CHECK: %[[DEQUANT_INPUT_2:.+]] = torch.aten.dequantize.self %[[QUANT_INPUT_2]] : !torch.vtensor<[?,?,?,?],!torch.quint8> -> !torch.vtensor<[?,?,?,?],f32>
3755+
// CHECK-DAG: %[[DEQUANT_INPUT_1:.+]] = torch.aten.dequantize.self %[[QUANT_INPUT_1]] : !torch.vtensor<[?,?,?,?],!torch.quint8> -> !torch.vtensor<[?,?,?,?],f32>
3756+
// CHECK-DAG: %[[DEQUANT_INPUT_2:.+]] = torch.aten.dequantize.self %[[QUANT_INPUT_2]] : !torch.vtensor<[?,?,?,?],!torch.quint8> -> !torch.vtensor<[?,?,?,?],f32>
37573757
// CHECK-DAG: %[[CONCAT_LIST:.+]] = torch.prim.ListConstruct %[[DEQUANT_INPUT_1]], %[[DEQUANT_INPUT_2]] : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>) -> !torch.list<vtensor>
37583758
// CHECK: %[[AXIS:.+]] = torch.constant.int 1
37593759
// CHECK: %[[CONCAT:.+]] = torch.aten.cat %[[CONCAT_LIST]], %[[AXIS]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>

0 commit comments

Comments
 (0)