@@ -271,11 +271,11 @@ static LogicalResult verifyConvOp(T op) {
271271 }
272272 }
273273
274- bool inputIsQuant = ! llvm::isa<FloatType>(inputEType);
275- bool weightIsQuant = ! llvm::isa<FloatType>(weightEType);
274+ bool inputIsFloat = llvm::isa<FloatType>(inputEType);
275+ bool weightIsFloat = llvm::isa<FloatType>(weightEType);
276276
277- // Either both must be quantized or both unquantized .
278- if (inputIsQuant != weightIsQuant ) {
277+ // Either both must be float or both non-float .
278+ if (inputIsFloat != weightIsFloat ) {
279279 op.emitOpError (
280280 " expect both input and weight to be float or not together, got " )
281281 << inputEType << " and " << weightEType;
@@ -527,7 +527,12 @@ static void buildTransConvOpWithQuantInfo(
527527 auto quantAttr = ::buildConvOpQuantizationAttr (builder, input, weight);
528528
529529 if (quantAttr) {
530- result.addAttribute (" quantization_info" , quantAttr);
530+ result.addAttribute (" input_zp" ,
531+ builder.getI32IntegerAttr (
532+ static_cast <int32_t >(quantAttr.getInputZp ())));
533+ result.addAttribute (" weight_zp" ,
534+ builder.getI32IntegerAttr (
535+ static_cast <int32_t >(quantAttr.getWeightZp ())));
531536 result.addTypes (
532537 buildConvOpResultTypeInfo (builder, outputType, input, weight));
533538 } else {
@@ -563,7 +568,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
563568 auto quantAttr = ::buildMatMulOpQuantizationAttr (builder, a, b);
564569
565570 if (quantAttr) {
566- result.addAttribute (" quantization_info" , quantAttr);
571+ result.addAttribute (" a_zp" , builder.getI32IntegerAttr (
572+ static_cast <int32_t >(quantAttr.getAZp ())));
573+ result.addAttribute (" b_zp" , builder.getI32IntegerAttr (
574+ static_cast <int32_t >(quantAttr.getBZp ())));
567575
568576 auto inputType = llvm::dyn_cast<ShapedType>(a.getType ());
569577 assert (inputType && " Input must be a shaped tensor type!" );
@@ -603,8 +611,14 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
603611 result.addAttribute (" pad" , pad);
604612 result.addAttribute (" acc_type" , accType);
605613 auto quantAttr = buildUnaryOpQuantizationAttr (builder, input, outputType);
606- if (quantAttr)
607- result.addAttribute (" quantization_info" , quantAttr);
614+ if (quantAttr) {
615+ result.addAttribute (" input_zp" ,
616+ builder.getI32IntegerAttr (
617+ static_cast <int32_t >(quantAttr.getInputZp ())));
618+ result.addAttribute (" output_zp" ,
619+ builder.getI32IntegerAttr (
620+ static_cast <int32_t >(quantAttr.getOutputZp ())));
621+ }
608622 result.types .push_back (outputType);
609623}
610624
@@ -616,8 +630,15 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
616630 Value input) {
617631 result.addOperands (input);
618632 auto quantAttr = buildUnaryOpQuantizationAttr (builder, input, outputType);
619- if (quantAttr)
620- result.addAttribute (" quantization_info" , quantAttr);
633+ if (quantAttr) {
634+ // note: negateOp has attributes input1_zp and output_zp
635+ result.addAttribute (" input1_zp" ,
636+ builder.getI32IntegerAttr (
637+ static_cast <int32_t >(quantAttr.getInputZp ())));
638+ result.addAttribute (" output_zp" ,
639+ builder.getI32IntegerAttr (
640+ static_cast <int32_t >(quantAttr.getOutputZp ())));
641+ }
621642 result.types .push_back (outputType);
622643}
623644
@@ -629,8 +650,11 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
629650 Value paddings) {
630651 result.addOperands ({input, paddings});
631652 auto quantAttr = buildPadOpQuantizationAttr (builder, input);
632- if (quantAttr)
633- result.addAttribute (" quantization_info" , quantAttr);
653+ if (quantAttr) {
654+ result.addAttribute (" input_zp" ,
655+ builder.getI32IntegerAttr (
656+ static_cast <int32_t >(quantAttr.getInputZp ())));
657+ }
634658 result.types .push_back (outputType);
635659}
636660
@@ -643,8 +667,11 @@ static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
643667 Value padConst) {
644668 result.addOperands ({input, paddings, padConst});
645669 auto quantAttr = buildPadOpQuantizationAttr (builder, input);
646- if (quantAttr)
647- result.addAttribute (" quantization_info" , quantAttr);
670+ if (quantAttr) {
671+ result.addAttribute (" input_zp" ,
672+ builder.getI32IntegerAttr (
673+ static_cast <int32_t >(quantAttr.getInputZp ())));
674+ }
648675 result.types .push_back (outputType);
649676}
650677
@@ -898,9 +925,8 @@ LogicalResult FullyConnectedOp::verify() {
898925
899926 // Quantized types must carry an input zero point, and float types must
900927 // not carry one.
901- if ((inputIsQuant && !getQuantizationInfo ()) ||
902- (!inputIsQuant && getQuantizationInfo ())) {
903- emitOpError (" quantizationattr is required for quantized type, and not "
928+ if ((inputIsQuant && !getInputZp ()) || (!inputIsQuant && getInputZp ())) {
929+ emitOpError (" input zero point is required for quantized type, and not "
904930 " allowed for float type" );
905931 return failure ();
906932 }
0 commit comments