@@ -699,8 +699,8 @@ def dynamic_gru(input,
699
699
def gru_unit (input ,
700
700
hidden ,
701
701
size ,
702
- weight = None ,
703
- bias = None ,
702
+ param_attr = None ,
703
+ bias_attr = None ,
704
704
activation = 'tanh' ,
705
705
gate_activation = 'sigmoid' ):
706
706
"""
@@ -731,8 +731,8 @@ def gru_unit(input,
731
731
input (Variable): The fc transformed input value of current step.
732
732
hidden (Variable): The hidden value of lstm unit from previous step.
733
733
size (integer): The input dimension value.
734
- weight (ParamAttr): The weight parameters for gru unit. Default: None
735
- bias (ParamAttr): The bias parameters for gru unit. Default: None
734
+ param_attr (ParamAttr): The weight parameters for gru unit. Default: None
735
+ bias_attr (ParamAttr): The bias parameters for gru unit. Default: None
736
736
activation (string): The activation type for cell (actNode).
737
737
Default: 'tanh'
738
738
gate_activation (string): The activation type for gates (actGate).
@@ -764,34 +764,31 @@ def gru_unit(input,
764
764
size = size / 3
765
765
766
766
# create weight
767
- if weight is None :
768
- weight = helper .create_parameter (
769
- attr = helper .param_attr , shape = [size , 3 * size ], dtype = dtype )
767
+ weight = helper .create_parameter (
768
+ attr = helper .param_attr , shape = [size , 3 * size ], dtype = dtype )
770
769
770
+ gate = helper .create_tmp_variable (dtype )
771
+ reset_hidden_pre = helper .create_tmp_variable (dtype )
772
+ updated_hidden = helper .create_tmp_variable (dtype )
773
+ inputs = {'Input' : input , 'HiddenPrev' : hidden , 'Weight' : weight }
771
774
# create bias
772
-
773
- if bias is None :
775
+ if helper .bias_attr :
774
776
bias_size = [1 , 3 * size ]
775
777
bias = helper .create_parameter (
776
778
attr = helper .bias_attr , shape = bias_size , dtype = dtype , is_bias = True )
777
-
778
- gate = helper .create_tmp_variable (dtype )
779
- reset_hidden_pre = helper .create_tmp_variable (dtype )
780
- updated_hidden = helper .create_tmp_variable (dtype )
779
+ inputs ['Bias' ] = bias
781
780
782
781
helper .append_op (
783
782
type = 'gru_unit' ,
784
- inputs = {'Input' : input ,
785
- 'HiddenPrev' : hidden ,
786
- 'Weight' : weight },
783
+ inputs = inputs ,
787
784
outputs = {
788
785
'Gate' : gate ,
789
786
'ResetHiddenPrev' : reset_hidden_pre ,
790
787
'Hidden' : updated_hidden ,
791
788
},
792
789
attrs = {
793
- 'activation' : 0 ,
794
- 'gate_activation' : 1 ,
790
+ 'activation' : 2 , # tanh
791
+ 'gate_activation' : 1 , # sigmoid
795
792
})
796
793
797
794
return updated_hidden , reset_hidden_pre , gate
0 commit comments