Commit 9e5531e

[ONNX] Add support for Onnx.QLinearConcat op (#4116)
This commit adds the Onnx->Torch lowering for the [Onnx.QLinearConcat](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearConcat) op. The op concatenates per-tensor quantized tensors; the lowering dequantizes each input, concatenates the dequantized tensors, and requantizes the result with the output scale and zero-point.
1 parent cccb0c0 commit 9e5531e
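
Since the lowering hinges on per-tensor affine quantization, here is a minimal scalar sketch of that arithmetic (illustrative only, not part of the commit; the helper names and the uint8 range are assumptions matching the new test):

// QLinearConcat(y_scale, y_zp, x0, s0, z0, x1, s1, z1, ...) is equivalent to
// quantize(concat(dequantize(x0, s0, z0), dequantize(x1, s1, z1), ...),
//          y_scale, y_zp).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// real = (quant - zero_point) * scale
std::vector<float> dequantize(const std::vector<uint8_t> &x, float scale,
                              uint8_t zp) {
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    out[i] = (static_cast<float>(x[i]) - zp) * scale;
  return out;
}

// quant = clamp(round(real / scale) + zero_point, 0, 255)
std::vector<uint8_t> quantize(const std::vector<float> &v, float scale,
                              uint8_t zp) {
  std::vector<uint8_t> out(v.size());
  for (size_t i = 0; i < v.size(); ++i) {
    float q = std::nearbyint(v[i] / scale) + static_cast<float>(zp);
    out[i] = static_cast<uint8_t>(std::clamp(q, 0.0f, 255.0f));
  }
  return out;
}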

2 files changed: +123 -0

lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp

Lines changed: 103 additions & 0 deletions
@@ -645,4 +645,107 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
             y);
         return success();
       });
+  patterns.onOp(
+      "QLinearConcat", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Location loc = binder.getLoc();
+        Torch::ValueTensorType resultType;
+        SmallVector<Value> operands;
+        int64_t axis;
+        if (binder.tensorOperandsList(operands) ||
+            binder.s64IntegerAttr(axis, "axis") ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        SmallVector<Value> inputs, inputScales, inputZeroPoints;
+        for (unsigned i = 2; i < operands.size(); i = i + 3) {
+          inputs.push_back(operands[i]);
+          inputScales.push_back(operands[i + 1]);
+          inputZeroPoints.push_back(operands[i + 2]);
+        }
+
+        unsigned numInputs = (operands.size() - 2) / 3;
+        if (!(llvm::all_equal({inputs.size(), inputScales.size(),
+                               inputZeroPoints.size()}) &&
+              inputs.size() == numInputs))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Incompatible number of input operands, scales "
+                         "and/or zero-points");
+
+        auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
+                                                  Value zp) -> Value {
+          auto ty = cast<Torch::ValueTensorType>(v.getType());
+          auto newTy = getQTorchTypeFromTorchIntType(ty);
+          return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+              binder.getLoc(), newTy, v, scale, zp);
+        };
+
+        // Preparing the quantized inputs.
+        SmallVector<Value> quantizedInputs;
+        for (unsigned i = 0; i < numInputs; i++) {
+          Value scale, zeroPoint;
+          if (failed(extractPerTensorQuantizationArguments(
+                  rewriter, loc, /*scale=*/inputScales[i],
+                  /*zero_point=*/inputZeroPoints[i], scale, zeroPoint)))
+            return rewriter.notifyMatchFailure(
+                binder.op, "Incompatible scale and zero-point arguments "
+                           "for per-tensor quantization");
+
+          quantizedInputs.push_back(makePerTensor(inputs[i], scale, zeroPoint));
+        }
+
+        // Dequantizing the inputs.
+        SmallVector<Value> dequantizedInputs;
+        for (unsigned i = 0; i < numInputs; i++) {
+          Torch::ValueTensorType inputTy =
+              dyn_cast<Torch::ValueTensorType>(quantizedInputs[i].getType());
+          if (!inputTy || !inputTy.hasSizes())
+            return rewriter.notifyMatchFailure(
+                binder.op, "Expected the input tensors to be concatenated "
+                           "to have sizes");
+
+          inputTy = rewriter.getType<Torch::ValueTensorType>(
+              inputTy.getOptionalSizes(), rewriter.getF32Type());
+          dequantizedInputs.push_back(
+              rewriter.create<Torch::AtenDequantizeSelfOp>(loc, inputTy,
+                                                           quantizedInputs[i]));
+        }
+
+        // Concatenating the inputs.
+        Type listElemType =
+            cast<Torch::BaseTensorType>(dequantizedInputs[0].getType())
+                .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
+                                      /*optionalDtype=*/nullptr);
+        Type listType = Torch::ListType::get(listElemType);
+        Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.op->getLoc(), listType, dequantizedInputs);
+        Value cstAxis = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(axis));
+        auto concatTy = rewriter.getType<Torch::ValueTensorType>(
+            resultType.getOptionalSizes(), rewriter.getF32Type());
+        Value concat = rewriter.create<Torch::AtenCatOp>(loc, concatTy,
+                                                         tensorList, cstAxis);
+
+        // Quantizing the result of concatenated inputs.
+        Value yScale, yZp;
+        if (failed(extractPerTensorQuantizationArguments(
+                rewriter, loc, /*scale=*/operands[0],
+                /*zero_point=*/operands[1], yScale, yZp)))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Incompatible scale and zero-point arguments "
+                         "for per-tensor quantization");
+        Torch::ValueTensorType yTy = dyn_cast<Torch::ValueTensorType>(
+            getQTorchTypeFromTorchIntType(resultType));
+        Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(
+                rewriter.getIntegerType(64),
+                static_cast<int64_t>(
+                    Torch::getScalarTypeForType(yTy.getDtype()))));
+        Value result = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+            loc, yTy, concat, yScale, yZp, dtyVal);
+        rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op,
+                                                          resultType, result);
+        return success();
+      });
 }
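
For reference, the operand loop above starts at index 2 because the contrib-op signature puts Y_scale and Y_zero_point first, followed by repeating (input, scale, zero_point) triples. A standalone sketch of that unpacking arithmetic (illustrative, not from the commit):

#include <cassert>
#include <cstdio>

int main() {
  // Eight operands, as in the test below: Y_scale, Y_zero_point, two triples.
  const unsigned numOperands = 8;
  assert((numOperands - 2) % 3 == 0 && "malformed operand list");
  const unsigned numInputs = (numOperands - 2) / 3; // (8 - 2) / 3 == 2
  for (unsigned i = 2; i + 2 < numOperands; i += 3)
    std::printf("input at %u, scale at %u, zero_point at %u\n", i, i + 1,
                i + 2);
  std::printf("numInputs = %u\n", numInputs);
  return 0;
}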

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 20 additions & 0 deletions
@@ -3743,3 +3743,23 @@ func.func @test_qlinearleakyrelu(%arg0: !torch.vtensor<[?,32,?,?],ui8>, %arg1: !
 // CHECK: return %[[OUT]]
   return %0 : !torch.vtensor<[?,32,?,?],ui8>
 }
+
+// -----
+
+// CHECK-LABEL: @test_qlinearconcat(
+func.func @test_qlinearconcat(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],ui8>, %arg2: !torch.vtensor<[?,?,?,?],ui8>, %arg3: !torch.vtensor<[],f32>, %arg4: !torch.vtensor<[],ui8>, %arg5: !torch.vtensor<[?,?,?,?],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[?,?,?,?],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64} {
+  %0 = torch.operator "onnx.QLinearConcat"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[?,?,?,?],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[?,?,?,?],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[?,?,?,?],ui8>
+  // CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK-DAG: %[[QUANT_INPUT_1:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg2, %{{.+}}, %{{.+}} : !torch.vtensor<[?,?,?,?],ui8>, !torch.float, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.quint8>
+  // CHECK-DAG: %[[QUANT_INPUT_2:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg5, %{{.+}}, %{{.+}} : !torch.vtensor<[?,?,?,?],ui8>, !torch.float, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.quint8>
+  // CHECK: %[[DEQUANT_INPUT_1:.+]] = torch.aten.dequantize.self %[[QUANT_INPUT_1]] : !torch.vtensor<[?,?,?,?],!torch.quint8> -> !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: %[[DEQUANT_INPUT_2:.+]] = torch.aten.dequantize.self %[[QUANT_INPUT_2]] : !torch.vtensor<[?,?,?,?],!torch.quint8> -> !torch.vtensor<[?,?,?,?],f32>
+  // CHECK-DAG: %[[CONCAT_LIST:.+]] = torch.prim.ListConstruct %[[DEQUANT_INPUT_1]], %[[DEQUANT_INPUT_2]] : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>) -> !torch.list<vtensor>
+  // CHECK: %[[AXIS:.+]] = torch.constant.int 1
+  // CHECK: %[[CONCAT:.+]] = torch.aten.cat %[[CONCAT_LIST]], %[[AXIS]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: %[[DTY:.+]] = torch.constant.int 13
+  // CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[CONCAT]], %{{.+}}, %{{.+}}, %[[DTY]] : !torch.vtensor<[?,?,?,?],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.quint8>
+  // CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[?,?,?,?],!torch.quint8> -> !torch.vtensor<[?,?,?,?],ui8>
+  // CHECK: return %[[OUT]]
+  return %0 : !torch.vtensor<[?,?,?,?],ui8>
+}
