Skip to content

Commit 81f22bb

Browse files
authored
Merge pull request #11670 from JiayiFeng/refine_test_reader_transpiler
test reader transpiler
2 parents edd947d + 7de8d11 commit 81f22bb

File tree

1 file changed

+99
-1
lines changed

1 file changed

+99
-1
lines changed

python/paddle/fluid/io.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import shutil
1919

2020
from paddle.fluid.evaluator import Evaluator
21-
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
21+
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable
2222
from . import core
2323

2424
__all__ = [
@@ -1374,3 +1374,101 @@ def has_success(checkpoint_dir, cur_dir):
13741374
if success_num > current_dir:
13751375
current_dir = success_num
13761376
return current_dir
1377+
1378+
1379+
def get_test_program(filelist, program=None, startup_program=None):
    """
    Transpile current train program to a program to read test dataset
    if the program is using reader ops like "open_files_op".

    For every reader-producing op found in the startup program, a twin op
    is appended to the startup program whose output is a new reader
    variable named ``<train_reader_name>_test``. Reader variables in
    ``program`` are then renamed to the ``_test`` names so that the
    program reads from the test readers instead of the train readers.

    Args:
        filelist: Value stored in the "file_names" attribute of ops whose
            type matches the first reader op found in the startup program.
        program: Program to transpile. Defaults to
            ``default_main_program()``. It is modified in place.
        startup_program: Startup program holding the reader-creating ops.
            Defaults to ``default_startup_program()``. It is modified in
            place (new vars and ops are appended).

    Returns:
        The (modified) ``program``. If the startup program contains no
        reader op, ``program`` is returned untouched.
    """

    def _copy_reader_var_(block, var, new_name=None):
        # Create a persistable READER var in `block` carrying the same
        # shapes/dtypes as `var`.
        if new_name is None:
            new_name = var.name
        new_var = block.create_var(
            name=str(new_name), type=core.VarDesc.VarType.READER)
        new_var.desc.set_shapes(var.desc.shapes())
        new_var.desc.set_dtypes(var.desc.dtypes())
        new_var.persistable = True
        return new_var

    def _get_test_reader_name(train_reader_name):
        return train_reader_name + "_test"

    def _is_reader_op(op):
        # An op counts as a reader op when its first "Out" output is a
        # READER-typed variable.
        block = op.block
        if "Out" in op.output_names:
            reader_out = block.vars[op.output("Out")[0]]
            if reader_out.type == core.VarDesc.VarType.READER:
                return True
        return False

    if program is None:
        program = default_main_program()
    if startup_program is None:
        startup_program = default_startup_program()
    startup_block = startup_program.global_block()

    # 1. find out the original reader var name
    # Snapshot the list up front: step 2 appends ops to this same block.
    startup_reader_op_list = [
        op for op in startup_block.ops if _is_reader_op(op)
    ]

    if len(startup_reader_op_list) == 0:
        return program

    root_reader_op = startup_reader_op_list[0]
    train_test_reader_map = {}
    # 2. add operators to startup to open and read test data files
    for op in startup_reader_op_list:
        assert (len(op.output("Out")) == 1)
        train_reader_name = op.output("Out")[0]
        train_reader = startup_block.vars[train_reader_name]
        test_reader = _copy_reader_var_(
            startup_block,
            train_reader,
            new_name=_get_test_reader_name(train_reader_name))
        train_test_reader_map[train_reader.name] = test_reader

        test_op_inputs = {}
        for name in op.input_names:
            test_arg_vars = []
            for arg_name in op.input(name):
                if name == "UnderlyingReader":
                    # Chain decorated readers onto the test copy of the
                    # reader they wrap, not onto the train reader.
                    arg_var = train_test_reader_map[arg_name]
                else:
                    arg_var = startup_block.vars[arg_name]
                test_arg_vars.append(arg_var)
            test_op_inputs[name] = test_arg_vars

        test_op = startup_block.append_op(
            type=op.type,
            inputs=test_op_inputs,
            outputs={'Out': [test_reader]},
            attrs=op.attrs)
        # root reader op's filelist attr for read test files
        if op.type == root_reader_op.type:
            test_op.set_attr("file_names", filelist)
        if op.type == "create_multi_pass_reader":
            test_op.set_attr("pass_num", 1)

    # 3. rename reader vars in inference program to different name
    #    to avoid read from train data.
    main_block = program.global_block()
    # Collect the names first: rename_var changes main_block.vars, and
    # mutating a dict while iterating over it is an error on Python 3.
    reader_var_names = [
        var.name for var in main_block.vars.values()
        if var.type == core.VarDesc.VarType.READER
    ]
    for reader_var_name in reader_var_names:
        main_block.rename_var(
            str(reader_var_name),
            str(_get_test_reader_name(reader_var_name)))

    # Apply the test settings to the ops of the main program itself
    # (each `op` here), not to the ops appended to the startup program.
    for op in main_block.ops:
        if op.type == root_reader_op.type:
            op.set_attr("file_names", filelist)
        if op.type == "create_multi_pass_reader":
            op.set_attr("pass_num", 1)

    startup_program.sync_with_cpp()
    program.sync_with_cpp()

    return program

0 commit comments

Comments
 (0)