@@ -1343,20 +1343,44 @@ def __call__(self, *args, **kwargs):
1343
1343
else :
1344
1344
return self .outputs
1345
1345
1346
- def memory (self , init = None , shape = None , value = 0.0 , dtype = 'float32' ):
1346
+ def memory (self ,
1347
+ init = None ,
1348
+ shape = None ,
1349
+ value = 0.0 ,
1350
+ need_reorder = False ,
1351
+ dtype = 'float32' ):
1347
1352
self ._assert_in_rnn_block_ ('memory' )
1348
1353
if init is not None :
1349
1354
if not isinstance (init , Variable ):
1350
1355
raise TypeError (
1351
1356
"The input arg `init` of memory() must be a Variable" )
1352
1357
parent_block = self ._parent_block_ ()
1358
+ init_tensor = init
1359
+ if need_reorder == True :
1360
+ if self .lod_rank_table is None :
1361
+ raise ValueError (
1362
+ 'If set need_reorder to True, make sure step_input be '
1363
+ 'invoked before '
1364
+ 'memory(init=init, need_reordered=True, ...).' )
1365
+ init_reordered = parent_block .create_var (
1366
+ name = unique_name ('dynamic_rnn_mem_init_reordered' ),
1367
+ type = core .VarDesc .VarType .LOD_TENSOR ,
1368
+ dtype = init .dtype )
1369
+ parent_block .append_op (
1370
+ type = 'reorder_lod_tensor_by_rank' ,
1371
+ inputs = {
1372
+ 'X' : [init_tensor ],
1373
+ 'RankTable' : [self .lod_rank_table ]
1374
+ },
1375
+ outputs = {'Out' : [init_reordered ]})
1376
+ init_tensor = init_reordered
1353
1377
mem_array = parent_block .create_var (
1354
1378
name = unique_name ('dynamic_rnn_mem_array' ),
1355
1379
type = core .VarDesc .VarType .LOD_TENSOR_ARRAY ,
1356
1380
dtype = init .dtype )
1357
1381
parent_block .append_op (
1358
1382
type = 'write_to_array' ,
1359
- inputs = {'X' : init ,
1383
+ inputs = {'X' : init_tensor ,
1360
1384
'I' : self .zero_idx },
1361
1385
outputs = {'Out' : mem_array })
1362
1386
retv = array_read (array = mem_array , i = self .step_idx )
0 commit comments