@@ -14,15 +14,10 @@

import collections
import numpy as np
- import six
from ..... import compat as cpt
from .... import core
- from .... import Executor
from ....framework import IrGraph
from ....framework import IrNode
- from ....framework import Program
- from ....initializer import Constant
- from ....initializer import NumpyArrayInitializer
from .... import unique_name

__all__ = [
@@ -107,7 +102,6 @@ def __init__(self,
        self._window_size = window_size
        self._moving_rate = moving_rate

-        self._need_initialized = collections.OrderedDict()
        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        self._conv_ops = ['conv2d', 'depthwise_conv2d']
        self._quantizable_grad_ops = [
@@ -127,14 +121,17 @@ def apply(self, graph):
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
-        self._need_initialized.clear()
+        #sequential_execution = core.get_pass('sequential_execution_pass')
+        #sequential_execution.apply(graph.graph)
        self._is_test = graph.is_test()
        # marked the variable which has been dequantized.
        dequantized_vars = collections.OrderedDict()
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]

        def _transform_forward(graph, op):
            for var_node in op.inputs:
+                if var_node.name() not in op.input_arg_names():
+                    continue
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                else:
@@ -168,6 +165,8 @@ def _transform_forward(graph, op):
        def _transform_backward(graph, op):
            no_dequanted_input_vars = True
            for var_node in op.inputs:
+                if var_node.name() not in op.input_arg_names():
+                    continue
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                    graph.update_input_link(var_node, dequant_var_node, op)
@@ -188,25 +187,7 @@ def _transform_backward(graph, op):
        for op in ops:
            if op.name() in self._quantizable_grad_ops:
                _transform_backward(graph, op)
-
-        if len(self._need_initialized) > 0:
-            assert self._scope is not None, \
-                'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
-            assert self._place is not None, \
-                'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
-            init_program = Program()
-            for var_desc, initializer in six.iteritems(self._need_initialized):
-                var = init_program.global_block().create_var(
-                    name=var_desc.name(),
-                    shape=var_desc.shape(),
-                    dtype=var_desc.dtype(),
-                    type=var_desc.type(),
-                    lod_level=var_desc.lod_level(),
-                    persistable=var_desc.persistable())
-                initializer(var, init_program.global_block())
-            exe = Executor(self._place)
-            exe.run(program=init_program, scope=self._scope)
-
+        graph.resolve_hazard()
        return graph

    def _create_global_step(self, graph):
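The block removed above built a throwaway Program, re-declared every pending variable in its global block, attached an initializer to it, and ran the whole thing through an Executor just to fill a handful of scale and state tensors. The replacement writes the values straight into the scope instead. A minimal, runnable sketch of that mechanism, assuming the fluid 1.x scope/tensor API (the helper name and variable name below are illustrative, not part of the patch):

    import numpy as np
    import paddle.fluid.core as core

    def init_in_scope(scope, place, var_name, value):
        # scope.var() creates the variable if it does not exist yet;
        # tensor.set() copies the numpy array onto the target place.
        tensor = scope.var(var_name).get_tensor()
        tensor.set(np.asarray(value), place)

    scope, place = core.Scope(), core.CPUPlace()
    init_in_scope(scope, place, 'quant_scale@test',
                  np.array([0.001], dtype='float32'))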
@@ -222,8 +203,9 @@ def _create_global_step(self, graph):
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=core.VarDesc.VarType.INT64)
-        self._need_initialized[global_step_in.var()] = \
-            Constant(value=0, force_cpu=True)
+        self._init_var_node(
+            global_step_in, np.zeros(
+                [1], dtype='int64'))
        global_step_out = graph.create_var_node_from_desc(
            global_step_in.var())
        # The attribute of `op_role` is needed by ParallelExecutor.
@@ -300,7 +282,9 @@ def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype())
-        self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        inputs = {'X': var_node, 'InScale': scale_in_node}
@@ -313,7 +297,11 @@ def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[self._window_size],
            var_dtype=var_node.dtype())
-        self._need_initialized[scales_node.var()] = Constant(value=0)
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(
+            scales_node, np.zeros(
+                [self._window_size], dtype=data_type))
        inputs['Iter'] = self._global_step
        outputs['OutScales'] = scales_node
        attrs = {
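The two-line dtype guard that precedes these _init_var_node calls keeps the scale tensors in float64 only when the quantized variable itself is FP64 and falls back to float32 otherwise. A compact sketch of that selection as a standalone helper, reusing only the enum shown above (the helper name is illustrative, not part of the patch):

    from paddle.fluid import core

    def scale_dtype(var_node):
        # FP64 variables get float64 scales; everything else gets float32.
        return 'float64' if var_node.dtype() == core.VarDesc.VarType.FP64 \
            else 'float32'

    # e.g. self._init_var_node(scale_in_node,
    #                          np.array([0.001], dtype=scale_dtype(var_node)))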
@@ -353,7 +341,9 @@ def _insert_quant_moving_average_abs_max_op(self, graph, var_node,
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype())
-        self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        ins = {'X': var_node, 'InScale': scale_in_node}
@@ -364,13 +354,15 @@ def _insert_quant_moving_average_abs_max_op(self, graph, var_node,
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            var_dtype=var_node.dtype(),
            shape=[1])
-        self._need_initialized[state_in_node.var()] = Constant(value=1)
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(state_in_node, np.ones([1], dtype=data_type))
        accum_in_node = graph.create_persistable_node(
            name=unique_name.generate('accum'),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            var_dtype=var_node.dtype(),
            shape=[1])
-        self._need_initialized[accum_in_node.var()] = Constant(value=1)
+        self._init_var_node(accum_in_node, np.ones([1], dtype=data_type))
        state_out_node = graph.create_var_node_from_desc(state_in_node.var(
        ))
        accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
@@ -490,6 +482,16 @@ def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes,
        graph.link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node

+    def _init_var_node(self, var_node, value):
+        assert isinstance(
+            value, np.ndarray), 'The type of value should be numpy array.'
+        assert self._scope is not None, \
+            'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
+        assert self._place is not None, \
+            'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
+        tensor = self._scope.var(var_node.name()).get_tensor()
+        tensor.set(value, self._place)
+

    def _quantized_var_name(self, var_name):
        """
        Return quantized variable name for the input `var_name`.
@@ -592,7 +594,8 @@ def apply(self, graph):
                                                    self._weight_bits)
                    self._restore_var(input_arg_name, quantized_param_v)
                else:
-                    scale_v = graph.var_node(op_node.output('OutScale')[0])
+                    scale_v = self._to_node(op_node.outputs,
+                                            op_node.output('OutScale')[0])
                self._var_scale_map[input_arg_name] = scale_v

        ops = graph.all_op_nodes()
@@ -613,32 +616,35 @@ def apply(self, graph):
        for op_node in ops:
            # insert dequant_op after fc/conv, need to rename inputs of the followed ops
            for var_node in op_node.inputs:
-                name = var_node.name()
-                if name in self._op_output_rename_map:
-                    old_in = graph.var_node(name)
-                    new_in = self._op_output_rename_map[name]
+                if var_node.node in self._op_output_rename_map:
+                    old_in = var_node
+                    new_in = self._op_output_rename_map[var_node.node]
                    graph.update_input_link(old_in, new_in, op_node)

        # remove the unused var node in the graph
        self._remove_unused_var_nodes(graph)
+        graph.resolve_hazard()
        return graph

    def _remove_fake_quant_and_dequant_op(self, graph, op_node):
-        k = op_node.output('Out')[0]
-        v = op_node.input('X')[0]
-        if v not in self._op_input_rename_map:
-            self._op_input_rename_map[k] = v
+        k = self._to_node(op_node.outputs, op_node.output('Out')[0])
+        v = self._to_node(op_node.inputs, op_node.input('X')[0])
+        if v.node not in self._op_input_rename_map:
+            self._op_input_rename_map[k.node] = v
        else:
-            self._op_input_rename_map[k] = self._op_input_rename_map[v]
+            self._op_input_rename_map[k.node] = self._op_input_rename_map[
+                v.node]
        graph.safe_remove_nodes(op_node)

    def _insert_post_channel_dequant_op(self, graph, op_node):
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        for var_node in op_node.inputs:
            name = var_node.name()
-            if name in self._op_input_rename_map:
-                old_in = graph.var_node(name)
-                new_in = graph.var_node(self._op_input_rename_map[name])
+            if name not in op_node.input_arg_names():
+                continue
+            if var_node.node in self._op_input_rename_map:
+                old_in = var_node
+                new_in = self._op_input_rename_map[var_node.node]
                new_in.clear_outputs()
                graph.update_input_link(old_in, new_in, op_node)
            original_var_name = self._original_var_name(name)
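Both _op_input_rename_map and _op_output_rename_map are now keyed by var_node.node, the underlying graph node, and the cached replacement node is stored directly as the value. Name-based lookups through graph.var_node() become ambiguous once the pass starts adding quant/dequant nodes, because several nodes in an IrGraph can report the same name(). A runnable toy illustration of why the key matters (FakeNode is a stand-in, not an API in this file):

    class FakeNode(object):
        """Two distinct graph nodes that happen to share a name."""
        def __init__(self, name, node_id):
            self._name, self.node = name, node_id
        def name(self):
            return self._name

    a, b = FakeNode('conv2d_0.tmp_0', 1), FakeNode('conv2d_0.tmp_0', 2)
    by_name = {n.name(): n for n in (a, b)}  # collapses to a single entry
    by_node = {n.node: n for n in (a, b)}    # keeps both nodes distinct
    assert len(by_name) == 1 and len(by_node) == 2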
@@ -653,28 +659,20 @@ def _insert_post_channel_dequant_op(self, graph, op_node):
        assert isinstance(scale_v, IrNode)
        scale_var_node = self._var_scale_map[original_var_name]

-        if len(op_node.outputs) != 1:
+        if len(op_node.output_arg_names()) != 1:
            raise ValueError("Only support one output, but op %s has"
                             " more than one output." % (op_node.name()))

-        output_var_node = op_node.outputs[0]
+        output_var_node = self._to_node(op_node.outputs,
+                                        op_node.output_arg_names()[0])
        weight_scale_node = graph.create_persistable_node(
            name=unique_name.generate('channel_scale'),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[channel_scale.shape[0]],
            var_dtype=output_var_node.dtype())
-        init_program = Program()
-        weight_scale_var = init_program.global_block().create_var(
-            name=weight_scale_node.name(),
-            shape=weight_scale_node.shape(),
-            dtype=weight_scale_node.dtype(),
-            type=weight_scale_node.type(),
-            lod_level=weight_scale_node.var().lod_level(),
-            persistable=weight_scale_node.persistable())
-        initializer = NumpyArrayInitializer(value=channel_scale)
-        initializer(weight_scale_var, init_program.global_block())
-        exe = Executor(self._place)
-        exe.run(program=init_program, scope=self._scope)
+        data_type = 'float64' if output_var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(weight_scale_node, channel_scale.astype(data_type))
        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(output_var_node.name()),
            var_type=output_var_node.type(),
@@ -695,16 +693,18 @@ def _insert_post_channel_dequant_op(self, graph, op_node):
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(weight_scale_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
-        self._op_output_rename_map[output_var_node.name()] = dequant_var_node
+        self._op_output_rename_map[output_var_node.node] = dequant_var_node
        return dequant_var_node

    def _insert_post_dequant_op(self, graph, op_node):
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        for var_node in op_node.inputs:
            name = var_node.name()
-            if name in self._op_input_rename_map:
-                old_in = graph.var_node(name)
-                new_in = graph.var_node(self._op_input_rename_map[name])
+            if name not in op_node.input_arg_names():
+                continue
+            if var_node.node in self._op_input_rename_map:
+                old_in = var_node
+                new_in = self._op_input_rename_map[var_node.node]
                new_in.clear_outputs()
                graph.update_input_link(old_in, new_in, op_node)
            original_var_name = self._original_var_name(name)
@@ -720,11 +720,12 @@ def _insert_post_dequant_op(self, graph, op_node):
        assert isinstance(scale_v, IrNode)
        scale_var_node = self._var_scale_map[original_var_name]

-        if len(op_node.outputs) != 1:
+        if len(op_node.output_arg_names()) != 1:
            raise ValueError("Only support one output, but op %s has"
                             " more than one output." % (op_node.name()))

-        output_var_node = op_node.outputs[0]
+        output_var_node = self._to_node(op_node.outputs,
+                                        op_node.output_arg_names()[0])
        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(output_var_node.name()),
            var_type=output_var_node.type(),
@@ -742,9 +743,27 @@ def _insert_post_dequant_op(self, graph, op_node):
        graph.link_to(output_var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
-        self._op_output_rename_map[output_var_node.name()] = dequant_var_node
+        self._op_output_rename_map[output_var_node.node] = dequant_var_node
        return dequant_var_node

+    def _init_var_node(self, var_node, value):
+        assert isinstance(
+            value, np.ndarray), 'The type of value should be numpy array.'
+        assert self._scope is not None, \
+            'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
+        assert self._place is not None, \
+            'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
+        tensor = self._scope.var(var_node.name()).get_tensor()
+        tensor.set(value, self._place)
+
+    def _to_node(self, nodes, node_name):
+        target_node = None
+        for n in nodes:
+            if n.name() == node_name:
+                target_node = n
+        assert target_node is not None, "Cannot find the target node in the giving set."
+        return target_node
+

    def _load_var(self, name):
        return np.array(self._scope.find_var(name).get_tensor())
@@ -848,6 +867,7 @@ def apply(self, graph):

        # remove the unused var node in the graph
        self._remove_unused_var_nodes(graph)
+        graph.resolve_hazard()
        return graph

    def _convert_to_int8(self, graph, var_node):
@@ -930,5 +950,5 @@ def apply(self, graph):
            for output_node in op_node.outputs:
                graph.link_to(dequant_node, output_node)
            graph.safe_remove_nodes(op_node)
-
+        graph.resolve_hazard()
        return graph
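Every apply() in this file now calls graph.resolve_hazard() just before returning, giving the IrGraph a chance to reconcile variable nodes that share a name after ops have been inserted or removed. With the pass-specific edits elided, each pass now follows roughly this skeleton (a sketch, not a verbatim excerpt):

    def apply(self, graph):
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        # ... insert or strip fake_quantize / fake_dequantize ops ...
        self._remove_unused_var_nodes(graph)  # only in passes that prune vars
        graph.resolve_hazard()
        return graph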