|
| 1 | +########################################################################### |
| 2 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 3 | +# Licensed under the MIT License. See License.txt in the project root for |
| 4 | +# license information. |
| 5 | +########################################################################### |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import itertools |
| 9 | +import onnx |
| 10 | +from onnx import helper |
| 11 | +from onnx import onnx_pb as onnx_proto |
| 12 | + |
| 13 | + |
| 14 | +def _npfloat16_to_int(np_list): |
| 15 | + ''' |
| 16 | + Convert numpy float16 to python int. |
| 17 | +
|
| 18 | + :param np_list: numpy float16 list |
| 19 | + :return int_list: python int list |
| 20 | + ''' |
| 21 | + return [int(bin(_.view('H'))[2:].zfill(16), 2) for _ in np_list] |
| 22 | + |
| 23 | + |
| 24 | +def convert_tensor_float_to_float16(tensor): |
| 25 | + ''' |
| 26 | + Convert tensor float to float16. |
| 27 | +
|
| 28 | + :param tensor: TensorProto object |
| 29 | + :return tensor_float16: converted TensorProto object |
| 30 | +
|
| 31 | + Example: |
| 32 | +
|
| 33 | + :: |
| 34 | +
|
| 35 | + from onnxmltools.utils.float16_converter import convert_tensor_float_to_float16 |
| 36 | + new_tensor = convert_tensor_float_to_float16(tensor) |
| 37 | +
|
| 38 | + ''' |
| 39 | + if not isinstance(tensor, onnx_proto.TensorProto): |
| 40 | + raise ValueError('Expected input type is an ONNX TensorProto but got %s' % type(tensor)) |
| 41 | + |
| 42 | + if tensor.data_type == onnx_proto.TensorProto.FLOAT: |
| 43 | + tensor.data_type = onnx_proto.TensorProto.FLOAT16 |
| 44 | + # convert float_data (float type) to float16 and write to int32_data |
| 45 | + if tensor.float_data: |
| 46 | + int_list = _npfloat16_to_int(np.float16(tensor.float_data)) |
| 47 | + tensor.int32_data[:] = int_list |
| 48 | + tensor.float_data[:] = [] |
| 49 | + # convert raw_data (bytes type) |
| 50 | + if tensor.raw_data: |
| 51 | + # convert n.raw_data to float |
| 52 | + float32_list = np.fromstring(tensor.raw_data, dtype='float32') |
| 53 | + # convert float to float16 |
| 54 | + float16_list = np.float16(float32_list) |
| 55 | + # convert float16 to bytes and write back to raw_data |
| 56 | + tensor.raw_data = float16_list.tostring() |
| 57 | + return tensor |
| 58 | + |
| 59 | + |
| 60 | +def convert_float_to_float16(model): |
| 61 | + ''' |
| 62 | + Convert tensor float type in the ONNX ModelProto input to tensor float16. |
| 63 | +
|
| 64 | + :param model: ONNX ModelProto object |
| 65 | + :return: converted ONNX ModelProto object |
| 66 | +
|
| 67 | + Examples: |
| 68 | +
|
| 69 | + :: |
| 70 | +
|
| 71 | + Example 1: Convert ONNX ModelProto object: |
| 72 | + from onnxmltools.utils.float16_converter import convert_float_to_float16 |
| 73 | + new_onnx_model = convert_float_to_float16(onnx_model) |
| 74 | +
|
| 75 | + Example 2: Convert ONNX model binary file: |
| 76 | + from onnxmltools.utils.float16_converter import convert_float_to_float16 |
| 77 | + from onnxmltools.utils import load_model, save_model |
| 78 | + onnx_model = load_model('model.onnx') |
| 79 | + new_onnx_model = convert_float_to_float16(onnx_model) |
| 80 | + save_model(new_onnx_model, 'new_model.onnx') |
| 81 | +
|
| 82 | + ''' |
| 83 | + func_infer_shape = None |
| 84 | + if onnx.__version__ >= '1.2': |
| 85 | + try: |
| 86 | + from onnx.shape_inference import infer_shapes |
| 87 | + func_infer_shape = infer_shapes |
| 88 | + finally: |
| 89 | + pass |
| 90 | + |
| 91 | + domain_flag = 0 |
| 92 | + if not isinstance(model, onnx_proto.ModelProto): |
| 93 | + raise ValueError('Expected model type is an ONNX ModelProto but got %s' % type(model)) |
| 94 | + |
| 95 | + # create black list |
| 96 | + op_black_list = ['ArrayFeatureExtractor', 'Binarizer', 'CastMap', 'CategoryMapper', 'DictVectorizer', |
| 97 | + 'FeatureVectorizer', 'Imputer', 'LabelEncoder', 'LinearClassifier', 'LinearRegressor', 'Normalizer', |
| 98 | + 'OneHotEncoder', 'SVMClassifier', 'SVMRegressor', 'Scaler', 'TreeEnsembleClassifier', |
| 99 | + 'TreeEnsembleRegressor', 'ZipMap'] |
| 100 | + # create a queue for BFS |
| 101 | + queue = [] |
| 102 | + value_info_list = [] |
| 103 | + node_list = [] |
| 104 | + # type inference on input model |
| 105 | + if func_infer_shape is not None: |
| 106 | + model = func_infer_shape(model) |
| 107 | + queue.append(model) |
| 108 | + while queue: |
| 109 | + next_level = [] |
| 110 | + for q in queue: |
| 111 | + # if q is model, push q.graph (GraphProto) |
| 112 | + if isinstance(q, onnx_proto.ModelProto): |
| 113 | + next_level.append(q.graph) |
| 114 | + # if q is model.graph, push q.node.attribute (AttributeProto) |
| 115 | + if isinstance(q, onnx_proto.GraphProto): |
| 116 | + for n in q.node: |
| 117 | + # if n is in the black list (doesn't support float16), no conversion for the node, |
| 118 | + # and save the node for further processing |
| 119 | + if n.op_type in op_black_list: |
| 120 | + node_list.append(n) |
| 121 | + else: |
| 122 | + next_level.append(n.attribute) |
| 123 | + # if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto) |
| 124 | + if isinstance(q, onnx_proto.AttributeProto): |
| 125 | + next_level.append(q.g) |
| 126 | + for n in q.graphs: |
| 127 | + next_level.append(n) |
| 128 | + # if q is graph, process graph.initializer(TensorProto), input, output and value_info (ValueInfoProto) |
| 129 | + if isinstance(q, onnx_proto.GraphProto): |
| 130 | + for n in q.initializer: # TensorProto type |
| 131 | + n = convert_tensor_float_to_float16(n) |
| 132 | + # for all ValueInfoProto with tensor(float) type in input, output and value_info, convert them to |
| 133 | + # tensor(float16) except map and seq(map). And save them in value_info_list for further processing |
| 134 | + for n in itertools.chain(q.input, q.output, q.value_info): |
| 135 | + if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT: |
| 136 | + n.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 |
| 137 | + value_info_list.append(n) |
| 138 | + # if q is node.attribute, process node.attribute.t and node.attribute.tensors (TensorProto) |
| 139 | + if isinstance(q, onnx_proto.AttributeProto): |
| 140 | + for n in itertools.chain(q.t, q.tensors): |
| 141 | + n = convert_tensor_float_to_float16(n) |
| 142 | + queue = next_level |
| 143 | + |
| 144 | + # process the nodes in black list that doesn't support tensor(float16) |
| 145 | + for node in node_list: |
| 146 | + # if input's name is in the value_info_list meaning input is tensor(float16) type, insert a Cast node |
| 147 | + # before the node, change current node's input name and create new value_info for the new name |
| 148 | + for i in range(len(node.input)): |
| 149 | + input = node.input[i] |
| 150 | + for value_info in value_info_list: |
| 151 | + if input == value_info.name: |
| 152 | + # create new value_info for current node's new input name |
| 153 | + new_value_info = model.graph.value_info.add() |
| 154 | + new_value_info.CopyFrom(value_info) |
| 155 | + new_value_info.name = input + '_casted' |
| 156 | + new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT |
| 157 | + # add Cast node (from tensor(float16) to tensor(float) before current node |
| 158 | + attrs = {'name': input + 'Cast'} |
| 159 | + attrs['to'] = onnx_proto.TensorProto.FLOAT |
| 160 | + nodes = [helper.make_node('Cast', input, input + '_casted', kwargs=attrs)] |
| 161 | + model.graph.node.extend(nodes) |
| 162 | + # change current node's input name |
| 163 | + node.input[i] = input + '_casted' |
| 164 | + domain_flag = 1 |
| 165 | + continue |
| 166 | + # if output's name is in the value_info_list meaning output is tensor(float16) type, insert a float16 to |
| 167 | + # float Cast node after the node, change current node's output name and create new value_info for the new name |
| 168 | + for i in range(len(node.output)): |
| 169 | + output = node.output[i] |
| 170 | + for value_info in value_info_list: |
| 171 | + if output == value_info.name: |
| 172 | + # create new value_info for current node's new output |
| 173 | + new_value_info = model.graph.value_info.add() |
| 174 | + new_value_info.CopyFrom(value_info) |
| 175 | + new_value_info.name = output + '_casted' |
| 176 | + new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT |
| 177 | + # add Cast node (from tensor(float) to tensor(float16) after current node |
| 178 | + attrs = {'name': output + 'Cast'} |
| 179 | + attrs['to'] = onnx_proto.TensorProto.FLOAT16 |
| 180 | + nodes = [helper.make_node('Cast', output + '_casted', output, kwarg=attrs)] |
| 181 | + model.graph.node.extend(nodes) |
| 182 | + # change current node's input name |
| 183 | + node.output[i] = output + '_casted' |
| 184 | + domain_flag = 1 |
| 185 | + continue |
| 186 | + if domain_flag: |
| 187 | + # Create operator set for cast node |
| 188 | + op_set = model.opset_import.add() |
| 189 | + op_set.domain = "" |
| 190 | + op_set.version = 7 |
| 191 | + return model |
0 commit comments