Skip to content

Commit c01bb26

Browse files
committed
Add reorder flag for DynamicRNN's memory function.
1 parent cb6b468 commit c01bb26

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

python/paddle/v2/fluid/layers/control_flow.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,20 +1310,44 @@ def __call__(self, *args, **kwargs):
13101310
else:
13111311
return self.outputs
13121312

1313-
def memory(self, init=None, shape=None, value=0.0, dtype='float32'):
1313+
def memory(self,
1314+
init=None,
1315+
shape=None,
1316+
value=0.0,
1317+
need_reorder=False,
1318+
dtype='float32'):
13141319
self._assert_in_rnn_block_('memory')
13151320
if init is not None:
13161321
if not isinstance(init, Variable):
13171322
raise TypeError(
13181323
"The input arg `init` of memory() must be a Variable")
13191324
parent_block = self._parent_block_()
1325+
init_tensor = init
1326+
if need_reorder == True:
1327+
if self.lod_rank_table is None:
1328+
raise ValueError(
1329+
'If set need_reorder to True, make sure step_input be '
1330+
'invoked before '
1331+
'memory(init=init, need_reordered=True, ...).')
1332+
init_reordered = parent_block.create_var(
1333+
name=unique_name('dynamic_rnn_mem_init_reordered'),
1334+
type=core.VarDesc.VarType.LOD_TENSOR,
1335+
dtype=init.dtype)
1336+
parent_block.append_op(
1337+
type='reorder_lod_tensor_by_rank',
1338+
inputs={
1339+
'X': [init_tensor],
1340+
'RankTable': [self.lod_rank_table]
1341+
},
1342+
outputs={'Out': [init_reordered]})
1343+
init_tensor = init_reordered
13201344
mem_array = parent_block.create_var(
13211345
name=unique_name('dynamic_rnn_mem_array'),
13221346
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
13231347
dtype=init.dtype)
13241348
parent_block.append_op(
13251349
type='write_to_array',
1326-
inputs={'X': init,
1350+
inputs={'X': init_tensor,
13271351
'I': self.zero_idx},
13281352
outputs={'Out': mem_array})
13291353
retv = array_read(array=mem_array, i=self.step_idx)

0 commit comments

Comments
 (0)