Skip to content

Commit 14e8337

Browse files
author
Yancey
authored
expose h0 in dynamic_lstm (#11391)
* expose h0 in dynamic_lstm
* update per review comment
* update per review comment
* rename h0 to H0
1 parent 8453740 commit 14e8337

File tree

1 file changed

+24
-6
lines changed
  • python/paddle/fluid/layers

1 file changed

+24
-6
lines changed

python/paddle/fluid/layers/nn.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,10 @@ def embedding(input,
261261
return tmp
262262

263263

264-
# TODO(qijun): expose H0 and C0
265264
def dynamic_lstm(input,
266265
size,
266+
h_0=None,
267+
c_0=None,
267268
param_attr=None,
268269
bias_attr=None,
269270
use_peepholes=True,
@@ -324,6 +325,13 @@ def dynamic_lstm(input,
324325
(T X 4D), where T is the total time steps in this
325326
mini-batch, D is the hidden size.
326327
size(int): 4 * hidden size.
328+
h_0(Variable): The initial hidden state is an optional input, default is zero.
329+
This is a tensor with shape (N x D), where N is the
330+
batch size and D is the hidden size.
331+
c_0(Variable): The initial cell state is an optional input, default is zero.
332+
This is a tensor with shape (N x D), where N is the
333+
batch size. `h_0` and `c_0` may both be None, but only together (provide either both or neither).
334+
327335
param_attr(ParamAttr|None): The parameter attribute for the learnable
328336
hidden-hidden weights.
329337
@@ -387,12 +395,20 @@ def dynamic_lstm(input,
387395
cell = helper.create_tmp_variable(dtype)
388396
batch_gate = helper.create_tmp_variable(dtype)
389397
batch_cell_pre_act = helper.create_tmp_variable(dtype)
398+
inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
399+
batch_size = input.shape[0]
400+
if h_0:
401+
assert h_0.shape == (batch_size, size), \
402+
'The shape of h0 should be (batch_size, %d)' % size
403+
inputs['H0'] = h_0
404+
if c_0:
405+
assert c_0.shape == (batch_size, size), \
406+
'The shape of c0 should be (batch_size, %d)' % size
407+
inputs['C0'] = c_0
390408

391409
helper.append_op(
392410
type='lstm',
393-
inputs={'Input': input,
394-
'Weight': weight,
395-
'Bias': bias},
411+
inputs=inputs,
396412
outputs={
397413
'Hidden': hidden,
398414
'Cell': cell,
@@ -677,11 +693,13 @@ def dynamic_gru(input,
677693
attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
678694
bias = helper.create_parameter(
679695
attr=helper.bias_attr, shape=[1, 3 * size], dtype=dtype, is_bias=True)
696+
batch_size = input.shape[0]
680697
inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
681698
if h_0 != None:
682699
assert h_0.shape == (
683-
size, size), 'The shape of h0 should be(%d, %d)' % (size, size)
684-
inputs['h0'] = h_0
700+
batch_size, size
701+
), 'The shape of h0 should be(batch_size, %d)' % size
702+
inputs['H0'] = h_0
685703

686704
hidden = helper.create_tmp_variable(dtype)
687705
batch_gate = helper.create_tmp_variable(dtype)

0 commit comments

Comments
 (0)