@@ -258,11 +258,11 @@ REGISTER_OP("_MklQuantizedConv2D")
     .Attr("data_format: string = 'NHWC'")
     .Attr("strides: list(int)")
     .Attr("is_filter_const: bool = true")
-    .Attr(GetPaddingAttrString())
+    .Attr(GetPaddingAttrStringWithExplicit())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
     .Attr("padding_list: list(int) = []")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
       ShapeHandle unused;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -307,11 +307,11 @@ REGISTER_OP("_MklQuantizedConv2DAndRequantize")
     .Attr("data_format: string = 'NHWC'")
     .Attr("strides: list(int)")
     .Attr("is_filter_const: bool = true")
-    .Attr(GetPaddingAttrString())
+    .Attr(GetPaddingAttrStringWithExplicit())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
     .Attr("padding_list: list(int) = []")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
       ShapeHandle unused;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -354,11 +354,11 @@ REGISTER_OP("_MklQuantizedConv2DWithBias")
     .Attr("strides: list(int)")
     .Attr("is_filter_const: bool = true")
     .Attr("is_bias_const: bool = true")
-    .Attr(GetPaddingAttrString())
+    .Attr(GetPaddingAttrStringWithExplicit())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
     .Attr("padding_list: list(int) = []")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
       ShapeHandle unused, channel;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -405,11 +405,11 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndRequantize")
     .Attr("strides: list(int)")
     .Attr("is_filter_const: bool = true")
     .Attr("is_bias_const: bool = true")
-    .Attr(GetPaddingAttrString())
+    .Attr(GetPaddingAttrStringWithExplicit())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
     .Attr("padding_list: list(int) = []")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
       ShapeHandle unused;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -448,11 +448,11 @@ REGISTER_OP("_MklQuantizedConv2DAndRelu")
     .Attr("data_format: string = 'NHWC'")
     .Attr("strides: list(int)")
     .Attr("is_filter_const: bool = true")
-    .Attr(GetPaddingAttrString())
+    .Attr(GetPaddingAttrStringWithExplicit())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
     .Attr("padding_list: list(int) = []")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
       ShapeHandle unused, channel;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -494,11 +494,11 @@ REGISTER_OP("_MklQuantizedConv2DAndReluAndRequantize")
     .Attr("data_format: string = 'NHWC'")
     .Attr("strides: list(int)")
     .Attr("is_filter_const: bool = true")
-    .Attr(GetPaddingAttrString())
+    .Attr(GetPaddingAttrStringWithExplicit())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
     .Attr("padding_list: list(int) = []")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
       ShapeHandle unused;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -541,11 +541,12 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndRelu")
     .Attr("strides: list(int)")
     .Attr("is_filter_const: bool = true")
     .Attr("is_bias_const: bool = true")
-    .Attr(GetPaddingAttrString())
+    .Attr(GetPaddingAttrStringWithExplicit())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
     .Attr("padding_list: list(int) = []")
+    .Attr("alpha: float = 0.0")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
       ShapeHandle unused, channel;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -592,11 +593,12 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
     .Attr("strides: list(int)")
     .Attr("is_filter_const: bool = true")
     .Attr("is_bias_const: bool = true")
-    .Attr(GetPaddingAttrString())
+    .Attr(GetPaddingAttrStringWithExplicit())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
     .Attr("padding_list: list(int) = []")
+    .Attr("alpha: float = 0.0")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
       ShapeHandle unused;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -642,11 +644,11 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndRelu")
     .Attr("strides: list(int)")
     .Attr("is_filter_const: bool = true")
     .Attr("is_bias_const: bool = true")
-    .Attr(GetPaddingAttrString())
+    .Attr(GetPaddingAttrStringWithExplicit())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
     .Attr("padding_list: list(int) = []")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
       ShapeHandle unused, channel;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -700,11 +702,11 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
     .Attr("strides: list(int)")
     .Attr("is_filter_const: bool = true")
     .Attr("is_bias_const: bool = true")
-    .Attr(GetPaddingAttrString())
+    .Attr(GetPaddingAttrStringWithExplicit())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
     .Attr("padding_list: list(int) = []")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
       ShapeHandle unused;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -760,11 +762,121 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
     .Attr("strides: list(int)")
     .Attr("is_filter_const: bool = true")
     .Attr("is_bias_const: bool = true")
+    .Attr(GetPaddingAttrStringWithExplicit())
+    .Attr("dilations: list(int) = [1, 1, 1, 1]")
+    .Attr("padding_list: list(int) = []")
+    .SetShapeFn([](InferenceContext* c) {
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
+      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
+      c->set_output(1, c->Scalar());
+      c->set_output(2, c->Scalar());
+      return Status::OK();
+    });
+
+REGISTER_OP("_MklQuantizedConv2DWithBiasReluAndSum")
+    .Input("input: Tinput")
+    .Input("filter: Tfilter")
+    .Input("bias: float")
+    .Input("min_input: float")
+    .Input("max_input: float")
+    .Input("min_filter: float")
+    .Input("max_filter: float")
+    .Input("summand: float")
+    .Input("mkl_input: uint8")
+    .Input("mkl_filter: uint8")
+    .Input("mkl_bias: uint8")
+    .Input("mkl_min_input: uint8")
+    .Input("mkl_max_input: uint8")
+    .Input("mkl_min_filter: uint8")
+    .Input("mkl_max_filter: uint8")
+    .Input("mkl_summand: uint8")
+    .Output("output: out_type")
+    .Output("min_output: float")
+    .Output("max_output: float")
+    .Output("mkl_output: uint8")
+    .Output("mkl_min_output: uint8")
+    .Output("mkl_max_output: uint8")
+    .Attr("Tinput: quantizedtype")
+    .Attr("Tfilter: quantizedtype")
+    .Attr("T: quantizedtype")  // Additional attribute "T" for
+                               // enabling MklToTf conversion
+    .Attr("out_type: quantizedtype = DT_QINT32")
+    .Attr("data_format: string = 'NHWC'")
+    .Attr("strides: list(int)")
+    .Attr("is_filter_const: bool = true")
+    .Attr("is_bias_const: bool = true")
     .Attr(GetPaddingAttrString())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
     .Attr("padding_list: list(int) = []")
+    .Attr("alpha: float = 0.0")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
+      ShapeHandle unused, channel;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
+      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
+      c->set_output(1, channel);
+      c->set_output(2, channel);
+      return Status::OK();
+    });
+
+REGISTER_OP("_MklQuantizedConv2DWithBiasReluAndSumAndRequantize")
+    .Input("input: Tinput")
+    .Input("filter: Tfilter")
+    .Input("bias: Tbias")
+    .Input("min_input: float")
+    .Input("max_input: float")
+    .Input("min_filter: float")
+    .Input("max_filter: float")
+    .Input("min_freezed_output: float")
+    .Input("max_freezed_output: float")
+    .Input("summand: Tsummand")
+    .Input("min_summand: float")
+    .Input("max_summand: float")
+    .Input("mkl_input: uint8")
+    .Input("mkl_filter: uint8")
+    .Input("mkl_bias: uint8")
+    .Input("mkl_min_input: uint8")
+    .Input("mkl_max_input: uint8")
+    .Input("mkl_min_filter: uint8")
+    .Input("mkl_max_filter: uint8")
+    .Input("mkl_min_freezed_output: uint8")
+    .Input("mkl_max_freezed_output: uint8")
+    .Input("mkl_summand: uint8")
+    .Input("mkl_min_summand: uint8")
+    .Input("mkl_max_summand: uint8")
+    .Output("output: out_type")
+    .Output("min_output: float")
+    .Output("max_output: float")
+    .Output("mkl_output: uint8")
+    .Output("mkl_min_output: uint8")
+    .Output("mkl_max_output: uint8")
+    .Attr("Tinput: quantizedtype")
+    .Attr("Tfilter: quantizedtype")
+    .Attr("Tbias: {float, qint32}")
+    .Attr("Tsummand: quantizedtype")
+    .Attr("T: quantizedtype")  // Additional attribute "T" for
+                               // enabling MklToTf conversion
+    .Attr("out_type: quantizedtype = DT_QUINT8")
+    .Attr("data_format: string = 'NHWC'")
+    .Attr("strides: list(int)")
+    .Attr("is_filter_const: bool = true")
+    .Attr("is_bias_const: bool = true")
+    .Attr(GetPaddingAttrString())
+    .Attr("dilations: list(int) = [1, 1, 1, 1]")
+    .Attr("padding_list: list(int) = []")
+    .Attr("alpha: float = 0.0")
+    .SetShapeFn([](InferenceContext* c) {
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
       ShapeHandle unused;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -805,10 +917,10 @@ REGISTER_OP("_MklQuantizedConv2DPerChannel")
     .Attr("data_format: string = 'NHWC'")
     .Attr("strides: list(int)")
     .Attr("is_filter_const: bool = false")
-    .Attr(GetPaddingAttrString())
+    .Attr(GetPaddingAttrStringWithExplicit())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
       ShapeHandle unused, channel;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -1035,6 +1147,7 @@ REGISTER_OP("_MklQuantizedMatMulWithBiasAndDequantize")
     .Attr("transpose_a: bool = false")
     .Attr("transpose_b: bool = false")
     .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
+    .Attr("is_weight_const: bool = true")
     .SetShapeFn([](InferenceContext* c) {
       TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
       ShapeHandle unused;
@@ -1129,9 +1242,11 @@ REGISTER_OP("_MklQuantizedDepthwiseConv2D")
     .Attr("is_filter_const: bool = true")
     .Attr(GetPaddingAttrString())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
+    .Attr("padding_list: list(int) = []")
     .SetShapeFn([](InferenceContext* c) {
       // TODO(bhavanis): Print an error message during the return.
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(
+          shape_inference::Conv2DShape(c));
       ShapeHandle unused, channel;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -1179,8 +1294,10 @@ REGISTER_OP("_MklQuantizedDepthwiseConv2DWithBias")
     .Attr("is_bias_const: bool = true")
     .Attr(GetPaddingAttrString())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
+    .Attr("padding_list: list(int) = []")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+      TF_RETURN_IF_ERROR(
+          shape_inference::Conv2DShape(c));
       ShapeHandle unused, channel;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));