class InferenceTranspiler:
    def transpile(self, program, scope, place):
        '''
-        Transpile the program to a inference program by fused batch normalization.
+        Transpile the program. Support only fuse batch normalization now.
+
+        :param program: program to transpile
+        :type program: Program
+        :param scope: inference scope
+        :type scope: Scope
+        :param place: inference place
+        :type place: Place
+        '''
+        self.fuse_batch_norm(program, scope, place)
+
+    def fuse_batch_norm(self, program, scope, place):
+        '''
+        Transpile the program by fused batch normalization.

        The batch normalization followed the convolution or fully connected layer
        can be integrated with them. Doing so will give us a forward acceleration,
@@ -57,8 +70,6 @@ def transpile(self, program, scope, place):
        :type scope: Scope
        :param place: inference place
        :type place: Place
-        :return: program by fused batch normalization
-        :rtype: Program
        '''
        self.scope = scope
        self.place = place
@@ -96,7 +107,7 @@ def transpile(self, program, scope, place):
        # TODO(luotao): use clone() method to flush the program.desc in force,
        # since some large program.desc will not be flushed immediately.
        # And a better solution will be considered later.
-        return program.clone()
+        program = program.clone()

    # ====================== private transpiler functions =====================
    def _insert_bias_op(self, index, current_op, bn_op):
@@ -142,11 +153,25 @@ def _fuse_param(self, current_op, bn_op, bias_op, with_bias):
        :type with_bias: Int
        '''

-        def _load_tensor(param_name):
-            return self.scope.find_var(param_name[0]).get_tensor()
+        def _update_param(op, old_param_name, new_param):
+            # For the sake of remaining the original variables the same as before,
+            # create new variables in scope to store the new parameters.
+            old_param_name = old_param_name[0]
+            old_var = self.block.vars[old_param_name]
+            new_param_name = old_param_name + '_fuse_bn'
+            new_var = self.block.create_parameter(
+                name=new_param_name.encode('ascii'),
+                type=old_var.type,
+                dtype=old_var.dtype,
+                shape=old_var.shape)
+            op.rename_input(old_param_name, new_param_name)
+            self.scope.var(new_param_name)
+
+            tensor = self.scope.find_var(new_param_name).get_tensor()
+            tensor.set(np.array(new_param), self.place)

        def _load_param(param_name):
-            return np.array(_load_tensor(param_name))
+            return np.array(self.scope.find_var(param_name[0]).get_tensor())

        bias_bn = _load_param(bn_op.input("Bias"))  # Bias
        scale_bn = _load_param(bn_op.input("Scale"))  # Scale
@@ -155,8 +180,6 @@ def _load_param(param_name):

        # TODO(luotao1): consider only conv2d now. fc would be delt later.
        current_param = _load_param(current_op.input("Filter"))
-        current_tensor = _load_tensor(current_op.input("Filter"))
-
        std_bn = np.float32(np.sqrt(np.add(var_bn, 1e-5)))
        tmp = np.float32(np.divide(scale_bn, std_bn))
@@ -167,17 +190,16 @@ def _load_param(param_name):
        bias = np.zeros(bias_bn.shape)
        bias = np.float32(
            np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn))
-        bias_tensor = _load_tensor(bias_op.input("Y"))
-        bias_tensor.set(bias, self.place)

        # re-compute weight of conv2d
        tmp = tmp.reshape(tmp.shape[0], -1)
        dst_param = current_param.reshape((tmp.shape[0], -1))
        dst_param = np.float32(np.multiply(dst_param, tmp))
        dst_param = dst_param.reshape(current_param.shape)

-        # set the updated parameters
-        current_tensor.set(np.array(dst_param), self.place)
+        # update parameters
+        _update_param(current_op, current_op.input("Filter"), dst_param)
+        _update_param(bias_op, bias_op.input("Y"), bias)

        # collect the renamed input
        self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]
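For reference, the arithmetic behind _fuse_param is the standard conv + batch-norm folding: y = scale * (conv(x) - mean) / sqrt(var + eps) + bias, which equals a convolution with a per-output-channel rescaled filter plus a folded bias. The following is a minimal numpy sketch (not part of the commit; shapes and names are made up for illustration) that checks the equivalence on a single pixel of a 1x1 convolution, using the same formulas as the diff above.

# Sketch only: verify that folding batch norm into the conv filter and a bias
# reproduces conv followed by batch norm. All values are hypothetical.
import numpy as np

np.random.seed(0)
eps = 1e-5

filt = np.random.rand(4, 3, 1, 1).astype(np.float32)   # conv2d "Filter", 4 out / 3 in channels
x = np.random.rand(3).astype(np.float32)                # one input pixel

# batch-norm parameters (Scale, Bias, Mean, Variance in the code above)
scale_bn = np.random.rand(4).astype(np.float32)
bias_bn = np.random.rand(4).astype(np.float32)
mean_bn = np.random.rand(4).astype(np.float32)
var_bn = np.random.rand(4).astype(np.float32)

# original path: conv, then batch norm
conv_out = filt.reshape(4, 3) @ x
bn_out = scale_bn * (conv_out - mean_bn) / np.sqrt(var_bn + eps) + bias_bn

# fused path, mirroring _fuse_param: rescale the filter per output channel
# and fold mean/scale/bias into a single bias term
tmp = scale_bn / np.sqrt(var_bn + eps)
fused_filter = (filt.reshape(4, -1) * tmp[:, None]).reshape(filt.shape)
fused_bias = (0.0 - mean_bn) * tmp + bias_bn
fused_out = fused_filter.reshape(4, 3) @ x + fused_bias

assert np.allclose(bn_out, fused_out, atol=1e-4)

Note the API change in this commit: transpile() no longer returns a new Program; the program is cloned and mutated in place, and the rescaled weights are written to new '_fuse_bn' variables so the original parameters in scope stay untouched.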