@@ -18,9 +18,6 @@ def args_check(cls, node, **kwargs):
1818 num_directions = 2 if direction == "bidirectional" else 1
1919 if "clip" in node .attrs :
2020 exception .OP_UNSUPPORTED_EXCEPT ("GRU with clip" , "Tensorflow" )
21- if node .attrs .get ("linear_before_reset" , 0 ):
22- exception .OP_UNSUPPORTED_EXCEPT ("GRU with linear_before_reset" ,
23- "Tensorflow" )
2421 if "activations" in node .attrs :
2522 activations = list (map (lambda x : x .lower (), node .attrs ["activations" ]))
2623 if activations [0 ] != "sigmoid" :
@@ -63,11 +60,16 @@ def _custom_getter(cls,
6360 if names [- 2 ] == "gates" :
6461 new_w = tf .transpose (tf .concat ([w_r , w_z ], 0 ))
6562 new_r = tf .transpose (tf .concat ([r_r , r_z ], 0 ))
66- elif names [- 2 ] == "candidate" :
63+ elif names [- 2 ] == "candidate" or names [ - 3 ] == "candidate" :
6764 new_w = tf .transpose (w_h )
6865 new_r = tf .transpose (r_h )
69- kernel = tf .concat ([new_w , new_r ], 0 )
70- return kernel
66+ if names [- 2 ] == 'input_projection' :
67+ return new_w
68+ elif names [- 2 ] == 'hidden_projection' :
69+ return new_r
70+ else :
71+ return tf .concat ([new_w , new_r ], 0 )
72+
7173 if names [- 1 ] == "bias" :
7274 if len (node .inputs ) >= 4 :
7375 # onnx Wb[zrh], Rb[zrh]
@@ -81,10 +83,15 @@ def _custom_getter(cls,
8183 if names [- 2 ] == "gates" :
8284 w_b = tf .transpose (tf .concat ([w_b_r , w_b_z ], 0 ))
8385 r_b = tf .transpose (tf .concat ([r_b_r , r_b_z ], 0 ))
84- elif names [- 2 ] == "candidate" :
86+ elif names [- 2 ] == "candidate" or names [ - 3 ] == "candidate" :
8587 w_b = tf .transpose (w_b_h )
8688 r_b = tf .transpose (r_b_h )
87- return tf .add (w_b , r_b )
89+ if names [- 2 ] == 'input_projection' :
90+ return w_b
91+ elif names [- 2 ] == 'hidden_projection' :
92+ return r_b
93+ else :
94+ return tf .add (w_b , r_b )
8895 return getter (name , * args , ** kwargs )
8996 return getter (name , * args , ** kwargs )
9097
@@ -158,7 +165,12 @@ def _common(cls, node, **kwargs):
158165 rnn_kwargs ["time_major" ] = True
159166 rnn_kwargs ["dtype" ] = tf .float32
160167
161- outputs , states = cls .rnn (x , tf .nn .rnn_cell .GRUCell , cell_kwargs ,
168+ if node .attrs .get ("linear_before_reset" , 0 ):
169+ cell_class = tf .contrib .cudnn_rnn .CudnnCompatibleGRUCell
170+ else :
171+ cell_class = tf .nn .rnn_cell .GRUCell
172+
173+ outputs , states = cls .rnn (x , cell_class , cell_kwargs ,
162174 rnn_kwargs , tf_activations , direction )
163175
164176 if num_directions == 1 :
0 commit comments