@@ -645,4 +645,107 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
645645 y);
646646 return success ();
647647 });
648+ patterns.onOp (
649+ " QLinearConcat" , 1 ,
650+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
651+ Location loc = binder.getLoc ();
652+ Torch::ValueTensorType resultType;
653+ SmallVector<Value> operands;
654+ int64_t axis;
655+ if (binder.tensorOperandsList (operands) ||
656+ binder.s64IntegerAttr (axis, " axis" ) ||
657+ binder.tensorResultType (resultType))
658+ return failure ();
659+
660+ SmallVector<Value> inputs, inputScales, inputZeroPoints;
661+ for (unsigned i = 2 ; i < operands.size (); i = i + 3 ) {
662+ inputs.push_back (operands[i]);
663+ inputScales.push_back (operands[i + 1 ]);
664+ inputZeroPoints.push_back (operands[i + 2 ]);
665+ }
666+
667+ unsigned numInputs = (operands.size () - 2 ) / 3 ;
668+ if (!(llvm::all_equal ({inputs.size (), inputScales.size (),
669+ inputZeroPoints.size ()}) &&
670+ inputs.size () == numInputs))
671+ return rewriter.notifyMatchFailure (
672+ binder.op , " Incompatible number of input operands, scales and/or "
673+ " zero-points" );
674+
675+ auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
676+ Value zp) -> Value {
677+ auto ty = cast<Torch::ValueTensorType>(v.getType ());
678+ auto newTy = getQTorchTypeFromTorchIntType (ty);
679+ return rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
680+ binder.getLoc (), newTy, v, scale, zp);
681+ };
682+
683+ // Preparing the quantized inputs.
684+ SmallVector<Value> quantizedInputs;
685+ for (unsigned i = 0 ; i < numInputs; i++) {
686+ Value scale, zeroPoint;
687+ if (failed (extractPerTensorQuantizationArguments (
688+ rewriter, loc, /* scale=*/ inputScales[i],
689+ /* zero_point=*/ inputZeroPoints[i], scale, zeroPoint)))
690+ return rewriter.notifyMatchFailure (
691+ binder.op , " Incompatible scale and zero-points argument for "
692+ " per-tensor quantization" );
693+
694+ quantizedInputs.push_back (makePerTensor (inputs[i], scale, zeroPoint));
695+ }
696+
697+ // Dequantizing the inputs.
698+ SmallVector<Value> dequantizedInputs;
699+ for (unsigned i = 0 ; i < numInputs; i++) {
700+ Torch::ValueTensorType inputTy =
701+ dyn_cast<Torch::ValueTensorType>(quantizedInputs[i].getType ());
702+ if (!inputTy || !inputTy.hasSizes ())
703+ return rewriter.notifyMatchFailure (
704+ binder.op , " Expected tensor input operands to be concatenated "
705+ " to have sizes" );
706+
707+ inputTy = rewriter.getType <Torch::ValueTensorType>(
708+ inputTy.getOptionalSizes (), rewriter.getF32Type ());
709+ dequantizedInputs.push_back (
710+ rewriter.create <Torch::AtenDequantizeSelfOp>(loc, inputTy,
711+ quantizedInputs[i]));
712+ }
713+
714+ // Concatenating the inputs.
715+ Type listElemType =
716+ cast<Torch::BaseTensorType>(dequantizedInputs[0 ].getType ())
717+ .getWithSizesAndDtype (/* optionalSizes=*/ std::nullopt ,
718+ /* optionalDtype=*/ nullptr );
719+ Type listType = Torch::ListType::get (listElemType);
720+ Value tensorList = rewriter.create <Torch::PrimListConstructOp>(
721+ binder.op ->getLoc (), listType, dequantizedInputs);
722+ Value cstAxis = rewriter.create <Torch::ConstantIntOp>(
723+ loc, rewriter.getI64IntegerAttr (axis));
724+ auto concatTy = rewriter.getType <Torch::ValueTensorType>(
725+ resultType.getOptionalSizes (), rewriter.getF32Type ());
726+ Value concat = rewriter.create <Torch::AtenCatOp>(loc, concatTy,
727+ tensorList, cstAxis);
728+
729+ // Quantizing the result of concatenated inputs.
730+ Value yScale, yZp;
731+ if (failed (extractPerTensorQuantizationArguments (
732+ rewriter, loc, /* scale=*/ operands[0 ],
733+ /* zero_point=*/ operands[1 ], yScale, yZp)))
734+ return rewriter.notifyMatchFailure (
735+ binder.op , " Incompatible scale and zero-points argument for "
736+ " per-tensor quantization" );
737+ Torch::ValueTensorType yTy = dyn_cast<Torch::ValueTensorType>(
738+ getQTorchTypeFromTorchIntType (resultType));
739+ Value dtyVal = rewriter.create <Torch::ConstantIntOp>(
740+ loc, rewriter.getType <Torch::IntType>(),
741+ rewriter.getIntegerAttr (
742+ rewriter.getIntegerType (64 ),
743+ static_cast <int64_t >(
744+ Torch::getScalarTypeForType (yTy.getDtype ()))));
745+ Value result = rewriter.create <Torch::AtenQuantizePerTensorOp>(
746+ loc, yTy, concat, yScale, yZp, dtyVal);
747+ rewriter.replaceOpWithNewOp <Torch::AtenIntReprOp>(binder.op , resultType,
748+ result);
749+ return success ();
750+ });
648751}
0 commit comments