Skip to content

Commit 1511a04

Browse files
authored
Merge pull request #7540 from pkuyym/fix-7533
Add reorder flag for DynamicRNN's memory function.
2 parents 052c05b + c01bb26 commit 1511a04

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

python/paddle/v2/fluid/layers/control_flow.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,20 +1343,44 @@ def __call__(self, *args, **kwargs):
13431343
else:
13441344
return self.outputs
13451345

1346-
def memory(self, init=None, shape=None, value=0.0, dtype='float32'):
1346+
def memory(self,
1347+
init=None,
1348+
shape=None,
1349+
value=0.0,
1350+
need_reorder=False,
1351+
dtype='float32'):
13471352
self._assert_in_rnn_block_('memory')
13481353
if init is not None:
13491354
if not isinstance(init, Variable):
13501355
raise TypeError(
13511356
"The input arg `init` of memory() must be a Variable")
13521357
parent_block = self._parent_block_()
1358+
init_tensor = init
1359+
if need_reorder == True:
1360+
if self.lod_rank_table is None:
1361+
raise ValueError(
1362+
'If set need_reorder to True, make sure step_input be '
1363+
'invoked before '
1364+
'memory(init=init, need_reordered=True, ...).')
1365+
init_reordered = parent_block.create_var(
1366+
name=unique_name('dynamic_rnn_mem_init_reordered'),
1367+
type=core.VarDesc.VarType.LOD_TENSOR,
1368+
dtype=init.dtype)
1369+
parent_block.append_op(
1370+
type='reorder_lod_tensor_by_rank',
1371+
inputs={
1372+
'X': [init_tensor],
1373+
'RankTable': [self.lod_rank_table]
1374+
},
1375+
outputs={'Out': [init_reordered]})
1376+
init_tensor = init_reordered
13531377
mem_array = parent_block.create_var(
13541378
name=unique_name('dynamic_rnn_mem_array'),
13551379
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
13561380
dtype=init.dtype)
13571381
parent_block.append_op(
13581382
type='write_to_array',
1359-
inputs={'X': init,
1383+
inputs={'X': init_tensor,
13601384
'I': self.zero_idx},
13611385
outputs={'Out': mem_array})
13621386
retv = array_read(array=mem_array, i=self.step_idx)

0 commit comments

Comments
 (0)