@@ -1310,20 +1310,44 @@ def __call__(self, *args, **kwargs):
1310
1310
else :
1311
1311
return self .outputs
1312
1312
1313
- def memory (self , init = None , shape = None , value = 0.0 , dtype = 'float32' ):
1313
+ def memory (self ,
1314
+ init = None ,
1315
+ shape = None ,
1316
+ value = 0.0 ,
1317
+ need_reorder = False ,
1318
+ dtype = 'float32' ):
1314
1319
self ._assert_in_rnn_block_ ('memory' )
1315
1320
if init is not None :
1316
1321
if not isinstance (init , Variable ):
1317
1322
raise TypeError (
1318
1323
"The input arg `init` of memory() must be a Variable" )
1319
1324
parent_block = self ._parent_block_ ()
1325
+ init_tensor = init
1326
+ if need_reorder == True :
1327
+ if self .lod_rank_table is None :
1328
+ raise ValueError (
1329
+ 'If set need_reorder to True, make sure step_input be '
1330
+ 'invoked before '
1331
+ 'memory(init=init, need_reordered=True, ...).' )
1332
+ init_reordered = parent_block .create_var (
1333
+ name = unique_name ('dynamic_rnn_mem_init_reordered' ),
1334
+ type = core .VarDesc .VarType .LOD_TENSOR ,
1335
+ dtype = init .dtype )
1336
+ parent_block .append_op (
1337
+ type = 'reorder_lod_tensor_by_rank' ,
1338
+ inputs = {
1339
+ 'X' : [init_tensor ],
1340
+ 'RankTable' : [self .lod_rank_table ]
1341
+ },
1342
+ outputs = {'Out' : [init_reordered ]})
1343
+ init_tensor = init_reordered
1320
1344
mem_array = parent_block .create_var (
1321
1345
name = unique_name ('dynamic_rnn_mem_array' ),
1322
1346
type = core .VarDesc .VarType .LOD_TENSOR_ARRAY ,
1323
1347
dtype = init .dtype )
1324
1348
parent_block .append_op (
1325
1349
type = 'write_to_array' ,
1326
- inputs = {'X' : init ,
1350
+ inputs = {'X' : init_tensor ,
1327
1351
'I' : self .zero_idx },
1328
1352
outputs = {'Out' : mem_array })
1329
1353
retv = array_read (array = mem_array , i = self .step_idx )
0 commit comments