@@ -174,11 +174,11 @@ def version_7(cls, ctx, node, **kwargs):
174
174
class CudnnRNN :
175
175
@classmethod
176
176
def version_11 (cls , ctx , node , ** kwargs ):
177
- X = node .input [0 ]
178
- X_shape = ctx .get_shape (X )
179
- H = node .input [1 ]
180
- H_shape = ctx .get_shape (H )
181
- P = node .input [3 ]
177
+ x = node .input [0 ]
178
+ x_shape = ctx .get_shape (x )
179
+ h = node .input [1 ]
180
+ h_shape = ctx .get_shape (h )
181
+ p = node .input [3 ]
182
182
utils .make_sure (
183
183
node .attr ["rnn_mode" ].s == b"gru" ,
184
184
"rnn mode other than gru are not supported yet"
@@ -192,9 +192,9 @@ def version_11(cls, ctx, node, **kwargs):
192
192
"input mode must be linear input"
193
193
)
194
194
num_dirs = 1 if node .attr ["direction" ].s == b"unidirectional" else 2
195
- num_layers = int (H_shape [0 ]/ num_dirs )
196
- num_units = hidden_size = H_shape [2 ]
197
- input_size = X_shape [2 ]
195
+ num_layers = int (h_shape [0 ]/ num_dirs )
196
+ num_units = hidden_size = h_shape [2 ]
197
+ input_size = x_shape [2 ]
198
198
w_shape = [num_layers * num_dirs , 3 * hidden_size , input_size ]
199
199
w_shape_const = ctx .make_const (utils .make_name ("w_shape" ), np .array (w_shape , dtype = np .int64 ))
200
200
r_shape = [num_layers * num_dirs , 3 * hidden_size , hidden_size ]
@@ -208,44 +208,44 @@ def version_11(cls, ctx, node, **kwargs):
208
208
r_end_const = ctx .make_const (utils .make_name ("r_end" ), np .array ([r_end ], dtype = np .int64 ))
209
209
b_end = r_end + np .prod (b_shape )
210
210
b_end_const = ctx .make_const (utils .make_name ("b_end" ), np .array ([b_end ], dtype = np .int64 ))
211
- def NM (nm ):
211
+ def Name (nm ):
212
212
return node .name + "_" + nm
213
- WS = [NM ('W_' + str (i )) for i in range (num_layers * num_dirs )]
214
- RS = [NM ('R_' + str (i )) for i in range (num_layers * num_dirs )]
215
- BS = [NM ('B_' + str (i )) for i in range (num_layers * num_dirs )]
216
- HS = [NM ('H_' + str (i )) for i in range (num_layers * num_dirs )]
217
- YHS = [NM ('YH_' + str (i )) for i in range (num_layers * num_dirs )]
218
- W_flattened = ctx .make_node ('Slice' , [P , zero_const .output [0 ], w_end_const .output [0 ]])
219
- R_flattened = ctx .make_node ('Slice' , [P , w_end_const .output [0 ], r_end_const .output [0 ]])
220
- B_flattened = ctx .make_node ('Slice' , [P , r_end_const .output [0 ], b_end_const .output [0 ]])
221
- W = utils .make_name ('W' )
222
- R = utils .make_name ('R' )
223
- B = utils .make_name ('B' )
224
- ctx .make_node ('Reshape' , [W_flattened .output [0 ], w_shape_const .output [0 ]], outputs = [W ])
225
- ctx .make_node ('Reshape' , [R_flattened .output [0 ], r_shape_const .output [0 ]], outputs = [R ])
226
- ctx .make_node ('Reshape' , [B_flattened .output [0 ], b_shape_const .output [0 ]], outputs = [B ])
227
- ctx .make_node ('Split' , [W ], outputs = WS )
228
- ctx .make_node ('Split' , [R ], outputs = RS )
229
- ctx .make_node ('Split' , [B ], outputs = BS )
230
- ctx .make_node ('Split' , [H ], outputs = HS )
231
- XNF = XNB = X
213
+ ws = [Name ('W_' + str (i )) for i in range (num_layers * num_dirs )]
214
+ rs = [Name ('R_' + str (i )) for i in range (num_layers * num_dirs )]
215
+ bs = [Name ('B_' + str (i )) for i in range (num_layers * num_dirs )]
216
+ hs = [Name ('H_' + str (i )) for i in range (num_layers * num_dirs )]
217
+ yhs = [Name ('YH_' + str (i )) for i in range (num_layers * num_dirs )]
218
+ w_flattened = ctx .make_node ('Slice' , [p , zero_const .output [0 ], w_end_const .output [0 ]])
219
+ r_flattened = ctx .make_node ('Slice' , [p , w_end_const .output [0 ], r_end_const .output [0 ]])
220
+ b_flattened = ctx .make_node ('Slice' , [p , r_end_const .output [0 ], b_end_const .output [0 ]])
221
+ w = utils .make_name ('W' )
222
+ r = utils .make_name ('R' )
223
+ b = utils .make_name ('B' )
224
+ ctx .make_node ('Reshape' , [w_flattened .output [0 ], w_shape_const .output [0 ]], outputs = [w ])
225
+ ctx .make_node ('Reshape' , [r_flattened .output [0 ], r_shape_const .output [0 ]], outputs = [r ])
226
+ ctx .make_node ('Reshape' , [b_flattened .output [0 ], b_shape_const .output [0 ]], outputs = [b ])
227
+ ctx .make_node ('Split' , [w ], outputs = ws )
228
+ ctx .make_node ('Split' , [r ], outputs = rs )
229
+ ctx .make_node ('Split' , [b ], outputs = bs )
230
+ ctx .make_node ('Split' , [h ], outputs = hs )
231
+ xnf = xnb = x
232
232
for i in range (num_layers ):
233
233
suffix = '_' + str (i * num_dirs )
234
- ctx .make_node ('GRU' , [XNF , NM ('W' + suffix ), NM ('R' + suffix ), NM ('B' + suffix ), '' , NM ('H' + suffix )],
235
- outputs = [NM ('Y' + suffix ), NM ('YH' + suffix )],
234
+ ctx .make_node ('GRU' , [xnf , Name ('W' + suffix ), Name ('R' + suffix ), Name ('B' + suffix ), '' , Name ('H' + suffix )],
235
+ outputs = [Name ('Y' + suffix ), Name ('YH' + suffix )],
236
236
attr = {'direction' : 'forward' , 'hidden_size' : num_units })
237
- XNF = NM ( X + suffix )
238
- ctx .make_node ('Squeeze' , [NM ('Y' + suffix )], outputs = [XNF ], attr = {'axes' : [1 ]})
237
+ xnf = Name ( x + suffix )
238
+ ctx .make_node ('Squeeze' , [Name ('Y' + suffix )], outputs = [xnf ], attr = {'axes' : [1 ]})
239
239
if num_dirs == 2 :
240
240
suffix = '_' + str (i * 2 + 1 )
241
- ctx .make_node ('GRU' , [XNB , NM ('W' + suffix ), NM ('R' + suffix ), NM ('B' + suffix ), '' , NM ('H' + suffix )],
242
- outputs = [NM ('Y' + suffix ), NM ('YH' + suffix )],
241
+ ctx .make_node ('GRU' , [xnb , Name ('W' + suffix ), Name ('R' + suffix ), Name ('B' + suffix ), '' , Name ('H' + suffix )],
242
+ outputs = [Name ('Y' + suffix ), Name ('YH' + suffix )],
243
243
attr = {'direction' : 'reverse' , 'hidden_size' : num_units })
244
- XNB = NM ( X + suffix )
245
- ctx .make_node ('Squeeze' , [NM ('Y' + suffix )], outputs = [XNB ], attr = {'axes' : [1 ]})
244
+ xnb = Name ( x + suffix )
245
+ ctx .make_node ('Squeeze' , [Name ('Y' + suffix )], outputs = [xnb ], attr = {'axes' : [1 ]})
246
246
ctx .remove_node (node .name )
247
247
if num_dirs == 2 :
248
- ctx .make_node ('Concat' , [XNF , XNB ], outputs = [node .output [0 ]], attr = {'axis' : - 1 })
248
+ ctx .make_node ('Concat' , [xnf , xnb ], outputs = [node .output [0 ]], attr = {'axis' : - 1 })
249
249
else :
250
- ctx .make_node ('Identity' , [XNF ], outputs = [node .output [0 ]])
251
- ctx .make_node ('Concat' , YHS , outputs = [node .output [1 ]], attr = {'axis' : 0 })
250
+ ctx .make_node ('Identity' , [xnf ], outputs = [node .output [0 ]])
251
+ ctx .make_node ('Concat' , yhs , outputs = [node .output [1 ]], attr = {'axis' : 0 })
0 commit comments