@@ -192,14 +192,14 @@ def version_11(cls, ctx, node, **kwargs):
192
192
"input mode must be linear input"
193
193
)
194
194
num_dirs = 1 if node .attr ["direction" ].s == b"unidirectional" else 2
195
- num_layers = int (h_shape [0 ]/ num_dirs )
195
+ num_layers = int (h_shape [0 ] / num_dirs )
196
196
num_units = hidden_size = h_shape [2 ]
197
197
input_size = x_shape [2 ]
198
- w_shape = [num_layers * num_dirs , 3 * hidden_size , input_size ]
198
+ w_shape = [num_layers * num_dirs , 3 * hidden_size , input_size ]
199
199
w_shape_const = ctx .make_const (utils .make_name ("w_shape" ), np .array (w_shape , dtype = np .int64 ))
200
- r_shape = [num_layers * num_dirs , 3 * hidden_size , hidden_size ]
200
+ r_shape = [num_layers * num_dirs , 3 * hidden_size , hidden_size ]
201
201
r_shape_const = ctx .make_const (utils .make_name ("r_shape" ), np .array (r_shape , dtype = np .int64 ))
202
- b_shape = [num_layers * num_dirs , 6 * hidden_size ]
202
+ b_shape = [num_layers * num_dirs , 6 * hidden_size ]
203
203
b_shape_const = ctx .make_const (utils .make_name ("b_shape" ), np .array (b_shape , dtype = np .int64 ))
204
204
zero_const = ctx .make_const (utils .make_name ("zero" ), np .array ([0 ], dtype = np .int64 ))
205
205
w_end = np .prod (w_shape )
@@ -208,13 +208,15 @@ def version_11(cls, ctx, node, **kwargs):
208
208
r_end_const = ctx .make_const (utils .make_name ("r_end" ), np .array ([r_end ], dtype = np .int64 ))
209
209
b_end = r_end + np .prod (b_shape )
210
210
b_end_const = ctx .make_const (utils .make_name ("b_end" ), np .array ([b_end ], dtype = np .int64 ))
211
- def Name (nm ):
211
+
212
+ def name (nm ):
212
213
return node .name + "_" + nm
213
- ws = [Name ('W_' + str (i )) for i in range (num_layers * num_dirs )]
214
- rs = [Name ('R_' + str (i )) for i in range (num_layers * num_dirs )]
215
- bs = [Name ('B_' + str (i )) for i in range (num_layers * num_dirs )]
216
- hs = [Name ('H_' + str (i )) for i in range (num_layers * num_dirs )]
217
- yhs = [Name ('YH_' + str (i )) for i in range (num_layers * num_dirs )]
214
+
215
+ ws = [name ('W_' + str (i )) for i in range (num_layers * num_dirs )]
216
+ rs = [name ('R_' + str (i )) for i in range (num_layers * num_dirs )]
217
+ bs = [name ('B_' + str (i )) for i in range (num_layers * num_dirs )]
218
+ hs = [name ('H_' + str (i )) for i in range (num_layers * num_dirs )]
219
+ yhs = [name ('YH_' + str (i )) for i in range (num_layers * num_dirs )]
218
220
w_flattened = ctx .make_node ('Slice' , [p , zero_const .output [0 ], w_end_const .output [0 ]])
219
221
r_flattened = ctx .make_node ('Slice' , [p , w_end_const .output [0 ], r_end_const .output [0 ]])
220
222
b_flattened = ctx .make_node ('Slice' , [p , r_end_const .output [0 ], b_end_const .output [0 ]])
@@ -230,19 +232,21 @@ def Name(nm):
230
232
ctx .make_node ('Split' , [h ], outputs = hs )
231
233
xnf = xnb = x
232
234
for i in range (num_layers ):
233
- suffix = '_' + str (i * num_dirs )
234
- ctx .make_node ('GRU' , [xnf , Name ('W' + suffix ), Name ('R' + suffix ), Name ('B' + suffix ), '' , Name ('H' + suffix )],
235
- outputs = [Name ('Y' + suffix ), Name ('YH' + suffix )],
235
+ suffix = '_' + str (i * num_dirs )
236
+ ctx .make_node ('GRU' ,
237
+ [xnf , name ('W' + suffix ), name ('R' + suffix ), name ('B' + suffix ), '' , name ('H' + suffix )],
238
+ outputs = [name ('Y' + suffix ), name ('YH' + suffix )],
236
239
attr = {'direction' : 'forward' , 'hidden_size' : num_units })
237
- xnf = Name (x + suffix )
238
- ctx .make_node ('Squeeze' , [Name ('Y' + suffix )], outputs = [xnf ], attr = {'axes' : [1 ]})
240
+ xnf = name (x + suffix )
241
+ ctx .make_node ('Squeeze' , [name ('Y' + suffix )], outputs = [xnf ], attr = {'axes' : [1 ]})
239
242
if num_dirs == 2 :
240
- suffix = '_' + str (i * 2 + 1 )
241
- ctx .make_node ('GRU' , [xnb , Name ('W' + suffix ), Name ('R' + suffix ), Name ('B' + suffix ), '' , Name ('H' + suffix )],
242
- outputs = [Name ('Y' + suffix ), Name ('YH' + suffix )],
243
+ suffix = '_' + str (i * 2 + 1 )
244
+ ctx .make_node ('GRU' ,
245
+ [xnb , name ('W' + suffix ), name ('R' + suffix ), name ('B' + suffix ), '' , name ('H' + suffix )],
246
+ outputs = [name ('Y' + suffix ), name ('YH' + suffix )],
243
247
attr = {'direction' : 'reverse' , 'hidden_size' : num_units })
244
- xnb = Name (x + suffix )
245
- ctx .make_node ('Squeeze' , [Name ('Y' + suffix )], outputs = [xnb ], attr = {'axes' : [1 ]})
248
+ xnb = name (x + suffix )
249
+ ctx .make_node ('Squeeze' , [name ('Y' + suffix )], outputs = [xnb ], attr = {'axes' : [1 ]})
246
250
ctx .remove_node (node .name )
247
251
if num_dirs == 2 :
248
252
ctx .make_node ('Concat' , [xnf , xnb ], outputs = [node .output [0 ]], attr = {'axis' : - 1 })
0 commit comments