# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

| 4 | +""" |
| 5 | +tfl_math |
| 6 | +""" |

import logging
import numpy as np
from tf2onnx.handler import tfl_op
from tf2onnx import utils

logger = logging.getLogger(__name__)


# pylint: disable=unused-argument,missing-docstring,unused-variable,pointless-string-statement,invalid-name


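# TFLite ops carry an optional activation fused into the op itself via the
# 'fused_activation_function' attribute. TF/ONNX have no such concept, so the
# helper below strips the attribute and appends the activation as a separate
# node on the op's output.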
def separate_fused_activation_function(ctx, node):
    activation_fn = node.attr['fused_activation_function'].s
    del node.attr['fused_activation_function']
    if activation_fn == b'RELU':
        ctx.insert_new_node_on_output("Relu", node.output[0])
    elif activation_fn == b'RELU6':
        new_node = ctx.insert_new_node_on_output("Relu6", node.output[0])
        new_node.skip_conversion = False
    elif activation_fn == b'TANH':
        ctx.insert_new_node_on_output("Tanh", node.output[0])
    else:
        # TODO: SIGN_BIT and RELU_N1_TO_1 not supported yet
        utils.make_sure(activation_fn == b'NONE', "Unsupported fused activation function %s on node %s",
                        activation_fn, node.name)

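# The elementwise arithmetic ops map 1:1 to their TF counterparts (named via the
# decorator's tf_op argument); the only TFLite-specific work is splitting out
# the fused activation function.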
@tfl_op(["TFL_ADD"], tf_op="Add")
class TflAdd:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)

@tfl_op(["TFL_SUB"], tf_op="Sub")
class TflSub:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)

@tfl_op(["TFL_MUL"], tf_op="Mul")
class TflMul:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)

@tfl_op(["TFL_DIV"], tf_op="Div")
class TflDiv:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)

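# TFL_LOGISTIC and the reduce ops below also map directly to the TF ops named in
# the decorator and need no attribute rewriting.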
@tfl_op(["TFL_LOGISTIC"], tf_op="Sigmoid")
class TflLogistic:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        pass

@tfl_op(["TFL_REDUCE_MAX"], tf_op="Max")
@tfl_op(["TFL_REDUCE_ANY"], tf_op="Any")
@tfl_op(["TFL_REDUCE_PROD"], tf_op="Prod")
class TflReduceOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        pass

@tfl_op(["TFL_LOCAL_RESPONSE_NORMALIZATION"], tf_op="LRN")
class TFlLocalResponseNormalizationOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        node.attr["depth_radius"] = node.attr["radius"]
        del node.attr["radius"]

@tfl_op(["TFL_RANGE"], tf_op="Range")
class TflRangeOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        node.set_attr("Tidx", ctx.get_dtype(node.output[0]))

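# TFLite stores quantization parameters (scale, zero_point, quantized_dimension)
# as op attributes, while ONNX QuantizeLinear/DequantizeLinear take scale and
# zero point as inputs. The handlers below turn the attributes into const inputs,
# e.g. TFL_QUANTIZE(x) {scale=[s], zero_point=[z]} -> QuantizeLinear(x, s, z).
# Per-axis parameters additionally need the 'axis' attribute, which ONNX only
# supports from opset 13 on.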
@tfl_op(["TFL_QUANTIZE"], onnx_op="QuantizeLinear")
class TflQuantizeOp:
    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        scale = node.get_attr_value('scale')
        zero_point = node.get_attr_value('zero_point')
        axis = node.get_attr_value('quantized_dimension')
        np_q_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.output[0]))
        if len(scale) > 1 or len(zero_point) > 1:
            utils.make_sure(ctx.opset >= 13, "Opset 13 is required for per-axis quantization")
            node.set_attr("axis", axis)
            scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale, dtype=np.float32))
            zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point, dtype=np_q_type))
        else:
            scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale[0], dtype=np.float32))
            zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point[0], dtype=np_q_type))
        ctx.replace_inputs(node, [node.input[0], scale_node.output[0], zero_point_node.output[0]])
        del node.attr["scale"]
        del node.attr["zero_point"]
        del node.attr["quantized_dimension"]

@tfl_op(["TFL_DEQUANTIZE"], onnx_op="DequantizeLinear")
class TflDequantizeOp:
    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        scale = node.get_attr_value('scale')
        zero_point = node.get_attr_value('zero_point')
        axis = node.get_attr_value('quantized_dimension')
        np_q_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[0]))
        if len(scale) > 1 or len(zero_point) > 1:
            utils.make_sure(ctx.opset >= 13, "Opset 13 is required for per-axis quantization")
            node.set_attr("axis", axis)
            scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale, dtype=np.float32))
            zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point, dtype=np_q_type))
        else:
            scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale[0], dtype=np.float32))
            zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point[0], dtype=np_q_type))
        ctx.replace_inputs(node, [node.input[0], scale_node.output[0], zero_point_node.output[0]])
        del node.attr["scale"]
        del node.attr["zero_point"]
        del node.attr["quantized_dimension"]

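# TFLite's 'asymmetric_quantize_inputs' flag corresponds to dynamic-range
# quantization: float inputs are quantized on the fly at inference time. This is
# emulated by wrapping each not-yet-quantized input in a DynamicQuantizeLinear/
# DequantizeLinear pair (DynamicQuantizeLinear requires opset 11).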
def dynamic_quantize_inputs(ctx, node):
    if ctx.opset < 11:
        logger.warning("Opset 11 is required for asymmetric_quantize_inputs of node %s", node.name)
        return
    for i in range(len(node.input)):
        # Don't quantize inputs that are already quantized
        if node.inputs[i].type in ["DequantizeLinear", "TFL_DEQUANTIZE"]:
            continue
        dyn_quant = ctx.make_node("DynamicQuantizeLinear", [node.input[i]], output_count=3, op_name_scope=node.name)
        dyn_quant.skip_conversion = True
        dequant = ctx.make_node("DequantizeLinear", dyn_quant.output, op_name_scope=node.name)
        dequant.skip_conversion = True
        ctx.replace_input(node, node.input[i], dequant.output[0], input_index=i)

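# TFL_FULLY_CONNECTED becomes MatMul(x, Transpose(weights)), since TFLite stores
# the weights as [out_dim, in_dim] while MatMul expects [in_dim, out_dim]. An
# optional third input is the bias, which is split out into a trailing Add.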
@tfl_op(["TFL_FULLY_CONNECTED"])
class TflFullyConnectedOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)
        utils.make_sure(node.attr['weights_format'].s == b'DEFAULT',
                        "Only default weights format supported for fully connected op")
        utils.make_sure(node.attr['keep_num_dims'].i == 0,
                        "Only keep_num_dims=False supported for fully connected op")
        if node.attr['asymmetric_quantize_inputs'].i == 1:
            dynamic_quantize_inputs(ctx, node)

        transpose_node = ctx.insert_new_node_on_input(node, "Transpose", node.input[1],
                                                      name=None, input_index=1, perm=[1, 0])
        transpose_node.skip_conversion = True
        node.set_attr("transpose_a", 0)
        node.set_attr("transpose_b", 0)
        node.type = "MatMul"

        if len(node.input) == 3:
            # FIXME: Add a test for this
            bias_inp = node.input[2]
            ctx.replace_inputs(node, node.input[:2])
            add_node = ctx.insert_new_node_on_output("Add", node.output[0], inputs=[node.output[0], bias_inp])
            add_node.skip_conversion = True

        del node.attr["weights_format"]
        del node.attr["keep_num_dims"]
        del node.attr["asymmetric_quantize_inputs"]

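# TFLite's softmax takes a scaling factor and computes softmax(beta * logits),
# so beta is applied to the input of the Softmax node.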
@tfl_op(["TFL_SOFTMAX"], tf_op="Softmax")
class TFlSoftmaxOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        beta = node.get_attr_value("beta")
        beta_node = ctx.make_const(utils.make_name("beta"), np.array(beta, dtype=np.float32))
        mul_node = ctx.insert_new_node_on_input(node, "Mul", node.input[0], name=utils.make_name(node.name))
        ctx.replace_inputs(mul_node, [mul_node.input[0], beta_node.output[0]])