@@ -175,12 +175,6 @@ def version_7(cls, ctx, node, **kwargs):
175
175
class CudnnRNN :
176
176
@classmethod
177
177
def version_11 (cls , ctx , node , ** kwargs ):
178
- #print ("CudnnRNN captured")
179
- #print (node.attr["direction"].s)
180
- #ii = input(os.getpid())
181
- #print ("input\n", node.input)
182
- #print ("\noutput\n", node.output)
183
- #print ("\nattr\n", node.attr)
184
178
X = node .input [0 ]
185
179
X_shape = ctx .get_shape (X )
186
180
H = node .input [1 ]
@@ -190,6 +184,14 @@ def version_11(cls, ctx, node, **kwargs):
190
184
node .attr ["rnn_mode" ].s == b"gru" ,
191
185
"rnn mode other than gru are not supported yet"
192
186
)
187
+ utils .make_sure (
188
+ node .attr ["dropout" ].f == 0 ,
189
+ "dropout not supported yet"
190
+ )
191
+ utils .make_sure (
192
+ node .attr ["input_mode" ].s == b"linear_input" ,
193
+ "input mode must be linear input"
194
+ )
193
195
num_dirs = 1 if node .attr ["direction" ].s == b"unidirectional" else 2
194
196
num_layers = int (H_shape [0 ]/ num_dirs )
195
197
num_units = hidden_size = H_shape [2 ]
@@ -217,26 +219,24 @@ def NM(nm):
217
219
W_flattened = ctx .make_node ('Slice' , [P , zero_const .output [0 ], w_end_const .output [0 ]])
218
220
R_flattened = ctx .make_node ('Slice' , [P , w_end_const .output [0 ], r_end_const .output [0 ]])
219
221
B_flattened = ctx .make_node ('Slice' , [P , r_end_const .output [0 ], b_end_const .output [0 ]])
220
- # W = utils.make_name('W')
221
- # R = utils.make_name('R')
222
- # B = utils.make_name('B')
223
- W = ctx .make_node ('Reshape' , [W_flattened .output [0 ], w_shape_const .output [0 ]])
224
- R = ctx .make_node ('Reshape' , [R_flattened .output [0 ], r_shape_const .output [0 ]])
225
- B = ctx .make_node ('Reshape' , [B_flattened .output [0 ], b_shape_const .output [0 ]])
226
- ctx .make_node ('Split' , [W . output [ 0 ] ], outputs = WS )
227
- ctx .make_node ('Split' , [R . output [ 0 ] ], outputs = RS )
228
- ctx .make_node ('Split' , [B . output [ 0 ] ], outputs = BS )
222
+ W = utils .make_name ('W' )
223
+ R = utils .make_name ('R' )
224
+ B = utils .make_name ('B' )
225
+ ctx .make_node ('Reshape' , [W_flattened .output [0 ], w_shape_const .output [0 ]], outputs = [ W ])
226
+ ctx .make_node ('Reshape' , [R_flattened .output [0 ], r_shape_const .output [0 ]], outputs = [ R ])
227
+ ctx .make_node ('Reshape' , [B_flattened .output [0 ], b_shape_const .output [0 ]], outputs = [ B ])
228
+ ctx .make_node ('Split' , [W ], outputs = WS )
229
+ ctx .make_node ('Split' , [R ], outputs = RS )
230
+ ctx .make_node ('Split' , [B ], outputs = BS )
229
231
ctx .make_node ('Split' , [H ], outputs = HS )
230
232
XNF = XNB = X
231
- gru_nodes = []
232
- squeeze_nodes = []
233
233
for i in range (num_layers ):
234
- suffix = '_' + str (i * 2 )
235
- gru_nodes . append ( ctx .make_node ('GRU' , [XNF , NM ('W' + suffix ), NM ('R' + suffix ), NM ('B' + suffix ), '' , NM ('H' + suffix )],
236
- outputs = [NM ('Y' + suffix ), NM ('YH' + suffix )],
237
- attr = {'direction' :'forward' , 'hidden_size' :num_units }) )
234
+ suffix = '_' + str (i * num_dirs )
235
+ ctx .make_node ('GRU' , [XNF , NM ('W' + suffix ), NM ('R' + suffix ), NM ('B' + suffix ), '' , NM ('H' + suffix )],
236
+ outputs = [NM ('Y' + suffix ), NM ('YH' + suffix )],
237
+ attr = {'direction' :'forward' , 'hidden_size' :num_units })
238
238
XNF = NM (X + suffix )
239
- squeeze_nodes . append ( ctx .make_node ('Squeeze' , [NM ('Y' + suffix )], outputs = [XNF ], attr = {'axes' : [1 ]}) )
239
+ ctx .make_node ('Squeeze' , [NM ('Y' + suffix )], outputs = [XNF ], attr = {'axes' : [1 ]})
240
240
if num_dirs == 2 :
241
241
suffix = '_' + str (i * 2 + 1 )
242
242
ctx .make_node ('GRU' , [XNB , NM ('W' + suffix ), NM ('R' + suffix ), NM ('B' + suffix ), '' , NM ('H' + suffix )],
@@ -249,5 +249,4 @@ def NM(nm):
249
249
ctx .make_node ('Concat' , [XNF , XNB ], outputs = [node .output [0 ]], attr = {'axis' : - 1 })
250
250
else :
251
251
identity_0 = ctx .make_node ('Identity' , [XNF ], outputs = [node .output [0 ]])
252
- concat_0 = ctx .make_node ('Concat' , YHS , outputs = [node .output [1 ]], attr = {'axis' : 0 })
253
- #print ("Done")
252
+ concat_0 = ctx .make_node ('Concat' , YHS , outputs = [node .output [1 ]], attr = {'axis' : 0 })
0 commit comments