@@ -261,9 +261,10 @@ def embedding(input,
261
261
return tmp
262
262
263
263
264
- # TODO(qijun): expose H0 and C0
265
264
def dynamic_lstm (input ,
266
265
size ,
266
+ h_0 = None ,
267
+ c_0 = None ,
267
268
param_attr = None ,
268
269
bias_attr = None ,
269
270
use_peepholes = True ,
@@ -324,6 +325,13 @@ def dynamic_lstm(input,
324
325
(T X 4D), where T is the total time steps in this
325
326
mini-batch, D is the hidden size.
326
327
size(int): 4 * hidden size.
328
+ h_0(Variable): The initial hidden state is an optional input; if not provided, it defaults to zero.
329
+ This is a tensor with shape (N x D), where N is the
330
+ batch size and D is the hidden size.
331
+ c_0(Variable): The initial cell state is an optional input; if not provided, it defaults to zero.
332
+ This is a tensor with shape (N x D), where N is the
333
+ batch size and D is the hidden size. `h_0` and `c_0` can both be None, but only at the same time.
334
+
327
335
param_attr(ParamAttr|None): The parameter attribute for the learnable
328
336
hidden-hidden weights.
329
337
@@ -387,12 +395,20 @@ def dynamic_lstm(input,
387
395
cell = helper .create_tmp_variable (dtype )
388
396
batch_gate = helper .create_tmp_variable (dtype )
389
397
batch_cell_pre_act = helper .create_tmp_variable (dtype )
398
+ inputs = {'Input' : input , 'Weight' : weight , 'Bias' : bias }
399
+ batch_size = input .shape [0 ]
400
+ if h_0 :
401
+ assert h_0 .shape == (batch_size , size ), \
402
+ 'The shape of h0 should be (batch_size, %d)' % size
403
+ inputs ['H0' ] = h_0
404
+ if c_0 :
405
+ assert c_0 .shape == (batch_size , size ), \
406
+ 'The shape of c0 should be (batch_size, %d)' % size
407
+ inputs ['C0' ] = c_0
390
408
391
409
helper .append_op (
392
410
type = 'lstm' ,
393
- inputs = {'Input' : input ,
394
- 'Weight' : weight ,
395
- 'Bias' : bias },
411
+ inputs = inputs ,
396
412
outputs = {
397
413
'Hidden' : hidden ,
398
414
'Cell' : cell ,
@@ -677,11 +693,13 @@ def dynamic_gru(input,
677
693
attr = helper .param_attr , shape = [size , 3 * size ], dtype = dtype )
678
694
bias = helper .create_parameter (
679
695
attr = helper .bias_attr , shape = [1 , 3 * size ], dtype = dtype , is_bias = True )
696
+ batch_size = input .shape [0 ]
680
697
inputs = {'Input' : input , 'Weight' : weight , 'Bias' : bias }
681
698
if h_0 != None :
682
699
assert h_0 .shape == (
683
- size , size ), 'The shape of h0 should be(%d, %d)' % (size , size )
684
- inputs ['h0' ] = h_0
700
+ batch_size , size
701
+ ), 'The shape of h0 should be(batch_size, %d)' % size
702
+ inputs ['H0' ] = h_0
685
703
686
704
hidden = helper .create_tmp_variable (dtype )
687
705
batch_gate = helper .create_tmp_variable (dtype )
0 commit comments