Skip to content

Commit 125b6ed

Browse files
marvin-Yuliutongxuan
authored and committed
[Op] Add a list of Quantized* and _MklQuantized* ops. (#469)
1 parent 0fe2668 commit 125b6ed

File tree

6 files changed

+264
-26
lines changed

6 files changed

+264
-26
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
op {
2+
graph_op_name: "QuantizedConv2DWithBiasReluAndSum"
3+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
op {
2+
graph_op_name: "QuantizedConv2DWithBiasReluAndSumAndRequantize"
3+
}

tensorflow/core/framework/common_shape_fns.cc

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,36 @@ Status MatMulShape(shape_inference::InferenceContext* c) {
242242
return Status::OK();
243243
}
244244

245+
// Shape function for the filter-gradient of MatMul-like operations.
// Outputs: output 0 is the 2-D gradient matrix, output 1 is a vector sized
// like the gradient's column dimension (presumably a bias gradient — confirm
// against the registering op).
Status MatMulGradFilterShape(shape_inference::InferenceContext* c) {
  ShapeHandle a_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a_shape));

  ShapeHandle b_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b_shape));

  bool transpose_a, transpose_b;
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));

  // Candidate output dimensions: one taken from `a` (which of its two dims
  // depends on transpose_a), the other always the second dim of `b`.
  const DimensionHandle dim_from_a =
      transpose_a ? c->Dim(a_shape, 0) : c->Dim(a_shape, 1);
  const DimensionHandle dim_from_b = c->Dim(b_shape, 1);

  // transpose_b flips which dimension lands on rows vs. columns.
  const DimensionHandle grad_rows = transpose_b ? dim_from_b : dim_from_a;
  const DimensionHandle grad_cols = transpose_b ? dim_from_a : dim_from_b;

  // Validate that the inner (contracted) dimensions are compatible.
  const DimensionHandle inner_a =
      transpose_a ? c->Dim(a_shape, 1) : c->Dim(a_shape, 0);
  const DimensionHandle inner_b = c->Dim(b_shape, 0);
  DimensionHandle merged;
  TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));

  c->set_output(0, c->Matrix(grad_rows, grad_cols));
  c->set_output(1, c->Vector(grad_cols));
  return Status::OK();
}
274+
245275
namespace {
246276

247277
// Validate that an Einsum subscript contains exactly one or zero ellipsis; and
@@ -663,7 +693,8 @@ Status ShapeFromDimensions(DimensionHandle batch_dim,
663693
namespace {
664694

665695
Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
666-
bool supports_explicit_padding) {
696+
bool supports_explicit_padding,
697+
string padding_attr_name = "explicit_paddings") {
667698
string data_format_str, filter_format_str;
668699
if (!c->GetAttr("data_format", &data_format_str).ok()) {
669700
data_format_str = "NHWC";
@@ -827,6 +858,11 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
827858
return Conv2DShapeImpl(c, false);
828859
}
829860

861+
// Shape function for QuantizedConv2D-like operations.
// Same as Conv2D shape inference, but with explicit-padding support enabled
// and the padding values read from the "padding_list" attribute instead of
// "explicit_paddings" (presumably consulted only when the padding attr is
// EXPLICIT — confirm in Conv2DShapeImpl).
Status QuantizedConv2DShape(shape_inference::InferenceContext* c) {
  return Conv2DShapeImpl(c, true, "padding_list");
}
865+
830866
// TODO(mjanusz): Unify all conv/pooling shape functions.
831867
Status Conv3DShape(shape_inference::InferenceContext* c) {
832868
ShapeHandle input_shape;

tensorflow/core/framework/common_shape_fns.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
225225

226226
// Shape function for MatMul-like operations.
227227
Status MatMulShape(shape_inference::InferenceContext* c);
228+
Status MatMulGradFilterShape(shape_inference::InferenceContext* c);
228229

229230
// Shape function for Batched MatMul-like operations with broadcasting across
230231
// batch dimensions.
@@ -249,6 +250,9 @@ Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c);
249250
// padding.
250251
Status Conv2DShape(shape_inference::InferenceContext* c);
251252

253+
// Shape function for QuantizedConv2D-like operations
254+
Status QuantizedConv2DShape(shape_inference::InferenceContext* c);
255+
252256
// Shape function for Conv3D-like operations.
253257
Status Conv3DShape(shape_inference::InferenceContext* c);
254258

tensorflow/core/ops/mkl_nn_ops.cc

Lines changed: 142 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -258,11 +258,11 @@ REGISTER_OP("_MklQuantizedConv2D")
258258
.Attr("data_format: string = 'NHWC'")
259259
.Attr("strides: list(int)")
260260
.Attr("is_filter_const: bool = true")
261-
.Attr(GetPaddingAttrString())
261+
.Attr(GetPaddingAttrStringWithExplicit())
262262
.Attr("dilations: list(int) = [1, 1, 1, 1]")
263263
.Attr("padding_list: list(int) = []")
264264
.SetShapeFn([](InferenceContext* c) {
265-
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
265+
TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
266266
ShapeHandle unused;
267267
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
268268
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -307,11 +307,11 @@ REGISTER_OP("_MklQuantizedConv2DAndRequantize")
307307
.Attr("data_format: string = 'NHWC'")
308308
.Attr("strides: list(int)")
309309
.Attr("is_filter_const: bool = true")
310-
.Attr(GetPaddingAttrString())
310+
.Attr(GetPaddingAttrStringWithExplicit())
311311
.Attr("dilations: list(int) = [1, 1, 1, 1]")
312312
.Attr("padding_list: list(int) = []")
313313
.SetShapeFn([](InferenceContext* c) {
314-
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
314+
TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
315315
ShapeHandle unused;
316316
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
317317
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -354,11 +354,11 @@ REGISTER_OP("_MklQuantizedConv2DWithBias")
354354
.Attr("strides: list(int)")
355355
.Attr("is_filter_const: bool = true")
356356
.Attr("is_bias_const: bool = true")
357-
.Attr(GetPaddingAttrString())
357+
.Attr(GetPaddingAttrStringWithExplicit())
358358
.Attr("dilations: list(int) = [1, 1, 1, 1]")
359359
.Attr("padding_list: list(int) = []")
360360
.SetShapeFn([](InferenceContext* c) {
361-
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
361+
TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
362362
ShapeHandle unused, channel;
363363
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
364364
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -405,11 +405,11 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndRequantize")
405405
.Attr("strides: list(int)")
406406
.Attr("is_filter_const: bool = true")
407407
.Attr("is_bias_const: bool = true")
408-
.Attr(GetPaddingAttrString())
408+
.Attr(GetPaddingAttrStringWithExplicit())
409409
.Attr("dilations: list(int) = [1, 1, 1, 1]")
410410
.Attr("padding_list: list(int) = []")
411411
.SetShapeFn([](InferenceContext* c) {
412-
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
412+
TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
413413
ShapeHandle unused;
414414
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
415415
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -448,11 +448,11 @@ REGISTER_OP("_MklQuantizedConv2DAndRelu")
448448
.Attr("data_format: string = 'NHWC'")
449449
.Attr("strides: list(int)")
450450
.Attr("is_filter_const: bool = true")
451-
.Attr(GetPaddingAttrString())
451+
.Attr(GetPaddingAttrStringWithExplicit())
452452
.Attr("dilations: list(int) = [1, 1, 1, 1]")
453453
.Attr("padding_list: list(int) = []")
454454
.SetShapeFn([](InferenceContext* c) {
455-
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
455+
TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
456456
ShapeHandle unused, channel;
457457
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
458458
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -494,11 +494,11 @@ REGISTER_OP("_MklQuantizedConv2DAndReluAndRequantize")
494494
.Attr("data_format: string = 'NHWC'")
495495
.Attr("strides: list(int)")
496496
.Attr("is_filter_const: bool = true")
497-
.Attr(GetPaddingAttrString())
497+
.Attr(GetPaddingAttrStringWithExplicit())
498498
.Attr("dilations: list(int) = [1, 1, 1, 1]")
499499
.Attr("padding_list: list(int) = []")
500500
.SetShapeFn([](InferenceContext* c) {
501-
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
501+
TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
502502
ShapeHandle unused;
503503
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
504504
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -541,11 +541,12 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndRelu")
541541
.Attr("strides: list(int)")
542542
.Attr("is_filter_const: bool = true")
543543
.Attr("is_bias_const: bool = true")
544-
.Attr(GetPaddingAttrString())
544+
.Attr(GetPaddingAttrStringWithExplicit())
545545
.Attr("dilations: list(int) = [1, 1, 1, 1]")
546546
.Attr("padding_list: list(int) = []")
547+
.Attr("alpha: float = 0.0")
547548
.SetShapeFn([](InferenceContext* c) {
548-
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
549+
TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
549550
ShapeHandle unused, channel;
550551
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
551552
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -592,11 +593,12 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
592593
.Attr("strides: list(int)")
593594
.Attr("is_filter_const: bool = true")
594595
.Attr("is_bias_const: bool = true")
595-
.Attr(GetPaddingAttrString())
596+
.Attr(GetPaddingAttrStringWithExplicit())
596597
.Attr("dilations: list(int) = [1, 1, 1, 1]")
597598
.Attr("padding_list: list(int) = []")
599+
.Attr("alpha: float = 0.0")
598600
.SetShapeFn([](InferenceContext* c) {
599-
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
601+
TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
600602
ShapeHandle unused;
601603
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
602604
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -642,11 +644,11 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndRelu")
642644
.Attr("strides: list(int)")
643645
.Attr("is_filter_const: bool = true")
644646
.Attr("is_bias_const: bool = true")
645-
.Attr(GetPaddingAttrString())
647+
.Attr(GetPaddingAttrStringWithExplicit())
646648
.Attr("dilations: list(int) = [1, 1, 1, 1]")
647649
.Attr("padding_list: list(int) = []")
648650
.SetShapeFn([](InferenceContext* c) {
649-
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
651+
TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
650652
ShapeHandle unused, channel;
651653
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
652654
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -700,11 +702,11 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
700702
.Attr("strides: list(int)")
701703
.Attr("is_filter_const: bool = true")
702704
.Attr("is_bias_const: bool = true")
703-
.Attr(GetPaddingAttrString())
705+
.Attr(GetPaddingAttrStringWithExplicit())
704706
.Attr("dilations: list(int) = [1, 1, 1, 1]")
705707
.Attr("padding_list: list(int) = []")
706708
.SetShapeFn([](InferenceContext* c) {
707-
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
709+
TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
708710
ShapeHandle unused;
709711
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
710712
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -760,11 +762,121 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
760762
.Attr("strides: list(int)")
761763
.Attr("is_filter_const: bool = true")
762764
.Attr("is_bias_const: bool = true")
765+
.Attr(GetPaddingAttrStringWithExplicit())
766+
.Attr("dilations: list(int) = [1, 1, 1, 1]")
767+
.Attr("padding_list: list(int) = []")
768+
.SetShapeFn([](InferenceContext* c) {
769+
TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
770+
ShapeHandle unused;
771+
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
772+
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
773+
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
774+
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
775+
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused));
776+
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
777+
TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
778+
c->set_output(1, c->Scalar());
779+
c->set_output(2, c->Scalar());
780+
return Status::OK();
781+
});
782+
783+
// MKL fused op: quantized Conv2D + bias + (leaky) Relu + summand addition.
// Fix: use GetPaddingAttrStringWithExplicit() instead of
// GetPaddingAttrString() for consistency with every sibling _MklQuantized*
// op in this change — they all pair QuantizedConv2DShape (which reads the
// "padding_list" attribute for explicit padding) with the explicit-padding
// attr string; with the plain attr string the EXPLICIT value is never
// accepted and "padding_list" is dead.
REGISTER_OP("_MklQuantizedConv2DWithBiasReluAndSum")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: float")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("summand: float")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_bias: uint8")
    .Input("mkl_min_input: uint8")
    .Input("mkl_max_input: uint8")
    .Input("mkl_min_filter: uint8")
    .Input("mkl_max_filter: uint8")
    .Input("mkl_summand: uint8")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Output("mkl_output: uint8")
    .Output("mkl_min_output: uint8")
    .Output("mkl_max_output: uint8")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("T: quantizedtype")  // Additional attribute "T" for
                               // enabling MklToTf conversion
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .Attr("alpha: float = 0.0")  // leaky-Relu slope; 0.0 = plain Relu (assumed — confirm in kernel)
    .SetShapeFn([](InferenceContext* c) {
      // Spatial output shape from the quantized-conv shape function, then
      // rank-check the scalar/vector min/max range inputs.
      TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
      ShapeHandle unused, channel;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));   // bias
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));   // min_input
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));   // max_input
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));  // min_filter
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));  // max_filter
      // Output ranges follow the (possibly per-channel) filter range shape.
      c->set_output(1, channel);
      c->set_output(2, channel);
      return Status::OK();
    });
831+
832+
REGISTER_OP("_MklQuantizedConv2DWithBiasReluAndSumAndRequantize")
833+
.Input("input: Tinput")
834+
.Input("filter: Tfilter")
835+
.Input("bias: Tbias")
836+
.Input("min_input: float")
837+
.Input("max_input: float")
838+
.Input("min_filter: float")
839+
.Input("max_filter: float")
840+
.Input("min_freezed_output: float")
841+
.Input("max_freezed_output: float")
842+
.Input("summand: Tsummand")
843+
.Input("min_summand: float")
844+
.Input("max_summand: float")
845+
.Input("mkl_input: uint8")
846+
.Input("mkl_filter: uint8")
847+
.Input("mkl_bias: uint8")
848+
.Input("mkl_min_input: uint8")
849+
.Input("mkl_max_input: uint8")
850+
.Input("mkl_min_filter: uint8")
851+
.Input("mkl_max_filter: uint8")
852+
.Input("mkl_min_freezed_output: uint8")
853+
.Input("mkl_max_freezed_output: uint8")
854+
.Input("mkl_summand: uint8")
855+
.Input("mkl_min_summand: uint8")
856+
.Input("mkl_max_summand: uint8")
857+
.Output("output: out_type")
858+
.Output("min_output: float")
859+
.Output("max_output: float")
860+
.Output("mkl_output: uint8")
861+
.Output("mkl_min_output: uint8")
862+
.Output("mkl_max_output: uint8")
863+
.Attr("Tinput: quantizedtype")
864+
.Attr("Tfilter: quantizedtype")
865+
.Attr("Tbias: {float, qint32}")
866+
.Attr("Tsummand: quantizedtype")
867+
.Attr("T: quantizedtype") // Additional attribute "T" for
868+
// enabling MklToTf conversion
869+
.Attr("out_type: quantizedtype = DT_QUINT8")
870+
.Attr("data_format: string = 'NHWC'")
871+
.Attr("strides: list(int)")
872+
.Attr("is_filter_const: bool = true")
873+
.Attr("is_bias_const: bool = true")
874+
.Attr(GetPaddingAttrString())
875+
.Attr("dilations: list(int) = [1, 1, 1, 1]")
876+
.Attr("padding_list: list(int) = []")
877+
.Attr("alpha: float = 0.0")
878+
.SetShapeFn([](InferenceContext* c) {
879+
TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
768880
ShapeHandle unused;
769881
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
770882
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -805,10 +917,10 @@ REGISTER_OP("_MklQuantizedConv2DPerChannel")
805917
.Attr("data_format: string = 'NHWC'")
806918
.Attr("strides: list(int)")
807919
.Attr("is_filter_const: bool = false")
808-
.Attr(GetPaddingAttrString())
920+
.Attr(GetPaddingAttrStringWithExplicit())
809921
.Attr("dilations: list(int) = [1, 1, 1, 1]")
810922
.SetShapeFn([](InferenceContext* c) {
811-
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
923+
TF_RETURN_IF_ERROR(shape_inference::QuantizedConv2DShape(c));
812924
ShapeHandle unused, channel;
813925
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
814926
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -1035,6 +1147,7 @@ REGISTER_OP("_MklQuantizedMatMulWithBiasAndDequantize")
10351147
.Attr("transpose_a: bool = false")
10361148
.Attr("transpose_b: bool = false")
10371149
.Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
1150+
.Attr("is_weight_const: bool = true")
10381151
.SetShapeFn([](InferenceContext* c) {
10391152
TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
10401153
ShapeHandle unused;
@@ -1129,9 +1242,11 @@ REGISTER_OP("_MklQuantizedDepthwiseConv2D")
11291242
.Attr("is_filter_const: bool = true")
11301243
.Attr(GetPaddingAttrString())
11311244
.Attr("dilations: list(int) = [1, 1, 1, 1]")
1245+
.Attr("padding_list: list(int) = []")
11321246
.SetShapeFn([](InferenceContext* c) {
11331247
// TODO(bhavanis): Print an error message during the return.
1134-
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
1248+
TF_RETURN_IF_ERROR(
1249+
shape_inference::Conv2DShape(c));
11351250
ShapeHandle unused, channel;
11361251
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
11371252
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
@@ -1179,8 +1294,10 @@ REGISTER_OP("_MklQuantizedDepthwiseConv2DWithBias")
11791294
.Attr("is_bias_const: bool = true")
11801295
.Attr(GetPaddingAttrString())
11811296
.Attr("dilations: list(int) = [1, 1, 1, 1]")
1297+
.Attr("padding_list: list(int) = []")
11821298
.SetShapeFn([](InferenceContext* c) {
1183-
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
1299+
TF_RETURN_IF_ERROR(
1300+
shape_inference::Conv2DShape(c));
11841301
ShapeHandle unused, channel;
11851302
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
11861303
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));

0 commit comments

Comments
 (0)