# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from .framework import Program
from .executor import global_scope
from . import core


class InferenceTranspiler:
    def transpile(self, program, place, scope=None):
        '''
        Transpile the program. Currently only batch normalization fusion is supported.

        :param program: program to transpile
        :type program: Program
        :param place: inference place
        :type place: Place
        :param scope: inference scope
        :type scope: Scope or None
        '''
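        # A minimal usage sketch (hypothetical variable names; assumes the
        # inference program has already been built or loaded, e.g. with
        # fluid.io.load_inference_model):
        #
        #     t = InferenceTranspiler()
        #     t.transpile(inference_program, core.CPUPlace())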
        if not isinstance(program, Program):
            raise TypeError("program should be a Program")
        if not isinstance(place, core.CPUPlace) and not isinstance(
                place, core.CUDAPlace):
            raise TypeError("place should be a CPUPlace or CUDAPlace")
        if scope is None:
            scope = global_scope()
        if not isinstance(scope, core.Scope):
            raise TypeError("scope should be a Scope or None")
        self.fuse_batch_norm(program, place, scope)

    def fuse_batch_norm(self, program, place, scope):
        '''
        Transpile the program by fusing batch normalization.

        A batch normalization that follows a convolution or fully connected
        layer can be folded into that layer. Doing so speeds up the forward
        pass, especially in environments like mobile or embedded devices.

        For input X:
        - Conv process:        X = input * W + bias
        - Batch norm process:  X' = (X - mean) / std
        - Scale process:       Y = a * X' + b

        After fusing into one operation:

        Y = (input * W + bias - mean) / std * a + b
          = input * a * W / std + ((bias - mean) / std * a + b)

        The operator transformation is:
        - before:
          - conv->batch_norm->any_other_op (bias == 0)
          - conv->elementwise_add->batch_norm->any_other_op (bias != 0)
        - after:
          - conv->elementwise_add->any_other_op

        The transpile stages are:
        1. insert an elementwise_add op when bias == 0.
        2. fuse the batch_norm's parameters into the conv and elementwise_add operators.
        3. remove the batch_norm ops that are no longer used by any other op.
        4. adjust the inputs of any_other_op to be the outputs of the elementwise_add operator.
        5. remove unused variables.

        :param program: program to transpile
        :type program: Program
        :param place: inference place
        :type place: Place
        :param scope: inference scope
        :type scope: Scope
        '''
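        # A worked sketch of the folding with made-up numbers: if W = 2.0,
        # bias = 0.5, mean = 1.0, std = 0.5, a = 3.0 and b = 0.1, then
        #   W'    = a / std * W                 = 6.0 * 2.0 = 12.0
        #   bias' = (bias - mean) / std * a + b = -3.0 + 0.1 = -2.9
        # and input * W' + bias' reproduces conv -> batch_norm -> scale.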
        self.scope = scope
        self.place = place
        self.block = program.block(0)
        self.input_map = {}  # stores the input names that should be adjusted

        i = 0
        while i < len(self.block.ops):
            current_op = self.block.ops[i]
            # TODO(luotao1): only conv2d is considered now; fc will be dealt
            # with later.
            if current_op.type in ['conv2d']:
                # TODO(luotao1): only single-chain networks are considered now.
                # For a branched network, we couldn't use block.ops[i + 1] as
                # the judgment condition.
                next_op = self.block.ops[i + 1]
                # conv2d without bias
                if (next_op.type == 'batch_norm'):
                    # insert bias op
                    bias_op = self._insert_bias_op(i + 1, current_op, next_op)
                    # fuse batch_norm
                    self._fuse_param(current_op, next_op, bias_op, 0)
                    # remove batch_norm_op
                    self.block.remove_op(i + 2)
                    i = i + 1
                # conv2d with bias, the next_op.type is elementwise_add
                elif (next_op.type == 'elementwise_add'):
                    next_next_op = self.block.ops[i + 2]
                    if (next_next_op.type == 'batch_norm'):
                        # fuse batch_norm
                        self._fuse_param(current_op, next_next_op, next_op, 1)
                        # remove batch_norm_op
                        self.block.remove_op(i + 2)
                        i = i + 1
            i = i + 1

        self._adjust_input()
        self._remove_unused_var()
        # TODO(luotao): use the clone() method to force the program.desc to be
        # flushed, since a large program.desc may not be flushed immediately.
        # A better solution will be considered later.
        program = program.clone()

    # ====================== private transpiler functions =====================
    def _insert_bias_op(self, index, current_op, bn_op):
        '''
        Construct an elementwise_add operator for adding bias
        and insert it into the program.

        :param index: insert location of bias_op
        :type index: Int
        :param current_op: current operator (conv or fc)
        :type current_op: Operator
        :param bn_op: batch norm operator
        :type bn_op: Operator
        :return: bias_op
        :rtype: Operator
        '''
        # The input of bias_op is current_op's output and Bias of bn_op
        # The output of bias_op is bn_op's output
        x_var = self.block.var(current_op.output("Output")[0])
        y_var = self.block.var(bn_op.input("Bias")[0])
        out_var = self.block.var(bn_op.output("Y")[0])

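        # axis=1 makes the per-channel Bias (shape [C]) broadcast over the
        # conv output starting at the channel dimension (NCHW layout assumed).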
        bias_op = self.block.insert_op(
            index,
            type="elementwise_add",
            inputs={"X": x_var,
                    "Y": y_var},
            outputs={"Out": out_var},
            attrs={"axis": 1})  # dim_start=1
        return bias_op

    def _fuse_param(self, current_op, bn_op, bias_op, with_bias):
        '''
        Fuse the batch_norm_op's parameters into current_op (conv or fc).

        :param current_op: current operator (conv or fc)
        :type current_op: Operator
        :param bn_op: batch norm operator
        :type bn_op: Operator
        :param bias_op: elementwise_add operator for adding bias
        :type bias_op: Operator
        :param with_bias: 1 if the current operator has bias, otherwise 0.
        :type with_bias: Int
        '''

        def _update_param(op, old_param_name, new_param):
            # To keep the original variables unchanged, create new variables
            # in the scope to store the new parameters.
            old_param_name = old_param_name[0]
            old_var = self.block.vars[old_param_name]
            new_param_name = old_param_name + '_fuse_bn'
            new_var = self.block.create_parameter(
                name=new_param_name.encode('ascii'),
                type=old_var.type,
                dtype=old_var.dtype,
                shape=old_var.shape)
            op.rename_input(old_param_name, new_param_name)
            self.scope.var(new_param_name)

            tensor = self.scope.find_var(new_param_name).get_tensor()
            tensor.set(np.array(new_param), self.place)

        def _load_param(param_name):
            return np.array(self.scope.find_var(param_name[0]).get_tensor())

        bias_bn = _load_param(bn_op.input("Bias"))  # Bias
        scale_bn = _load_param(bn_op.input("Scale"))  # Scale
        mean_bn = _load_param(bn_op.input("Mean"))  # Mean
        var_bn = _load_param(bn_op.input("Variance"))  # Variance

        # TODO(luotao1): only conv2d is considered now; fc will be dealt
        # with later.
        current_param = _load_param(current_op.input("Filter"))
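        # 1e-5 is assumed to match the epsilon used by the batch_norm op (its
        # default); std = sqrt(variance + epsilon).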
        std_bn = np.float32(np.sqrt(np.add(var_bn, 1e-5)))
        tmp = np.float32(np.divide(scale_bn, std_bn))

        # add the bias of batch_norm_op to conv2d
        if with_bias:
            bias = _load_param(bias_op.input("Y"))
        else:
            bias = np.zeros(bias_bn.shape)
        bias = np.float32(
            np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn))

        # re-compute the weight of conv2d
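        # tmp (= scale_bn / std_bn) has one entry per output channel; reshaping
        # it to (C_out, 1) lets it broadcast over the filter flattened to
        # (C_out, C_in * kh * kw).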
        tmp = tmp.reshape(tmp.shape[0], -1)
        dst_param = current_param.reshape((tmp.shape[0], -1))
        dst_param = np.float32(np.multiply(dst_param, tmp))
        dst_param = dst_param.reshape(current_param.shape)

        # update the parameters
        _update_param(current_op, current_op.input("Filter"), dst_param)
        _update_param(bias_op, bias_op.input("Y"), bias)

        # collect the renamed inputs
        self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]

    def _adjust_input(self):
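        '''
        Rename inputs that still refer to the removed batch_norm outputs so
        that they read from the matching elementwise_add outputs instead.
        '''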
        for i in range(len(self.block.ops)):
            current_op = self.block.ops[i]
            for input_arg in current_op.input_arg_names:
                if input_arg in self.input_map:
                    current_op.rename_input(input_arg,
                                            self.input_map[input_arg])

    def _remove_unused_var(self):
        '''
        Remove unused variables from the program.
        '''
        args = []
        for i in range(len(self.block.ops)):
            current_op = self.block.ops[i]
            args += current_op.input_arg_names
            args += current_op.output_arg_names
        args = list(set(args))  # unique the input and output arguments

        # iterate over a copy of the keys, since remove_var mutates block.vars
        for var in list(self.block.vars.keys()):
            if var not in args:
                self.block.remove_var(var)