Skip to content

Commit a424f5a

Browse files
author
yi.wu
committed
polish reader op for test
1 parent 343c195 commit a424f5a

File tree

1 file changed

+65
-51
lines changed
  • python/paddle/fluid/layers

1 file changed

+65
-51
lines changed

python/paddle/fluid/layers/io.py

Lines changed: 65 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,9 @@ def reset():
259259
return reader
260260

261261

262-
def _copy_reader_var_(block, var):
262+
def _copy_reader_var_(block, var, newname=None):
263+
if newname == None:
264+
newname = var.name
263265
new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
264266
new_var.desc.set_shapes(var.desc.shapes())
265267
new_var.desc.set_dtypes(var.desc.dtypes())
@@ -691,68 +693,80 @@ def load(out, file_path, load_as_fp16=None):
691693
helper.append_op(type="load", inputs={}, output={"Out": out}, args=attrs)
692694

693695

694-
def get_test_program(filelist, test_program=None, startup_program=None):
695-
"""
696-
Transpile current program to read test dataset if the program
697-
is using reader ops like "open_files_op".
696+
def _is_reader_op(op, block):
    """
    Tell whether ``op`` produces a reader variable.

    Args:
        op: operator to inspect; must expose ``output_names`` and
            ``output(name)``.
        block: block whose ``vars`` mapping holds the operator's output
            variable.

    Returns:
        bool: True when ``op`` has an "Out" output and the first variable
        listed under it has type ``core.VarDesc.VarType.READER``,
        otherwise False.
    """
    # Operators without an "Out" slot can never be reader ops.
    if "Out" not in op.output_names:
        return False

    # Only the first variable of the "Out" slot is examined.
    out_name = op.output("Out")[0]
    return block.vars[out_name].type == core.VarDesc.VarType.READER
698702

699-
Args:
700-
filelist (list): list of test file paths.
701-
test_program (Program|None): program to run test/evaluation.
702-
default use fluid.default_main_program()
703-
startup_program (Program|None): startup program to change,
704-
default use fluid.default_startup_program()
705-
706-
Returns:
707-
Program: program for test
703+
704+
def get_test_program(filelist, program=None, startup_program=None):
705+
"""
706+
Transpile current train program to a program to read test dataset
707+
if the program is using reader ops like "open_files_op".
708708
"""
709-
if test_program == None:
709+
if program == None:
710710
program = default_main_program()
711711
if startup_program == None:
712712
startup_program = default_startup_program()
713713

714714
# 1. find out the original reader var name
715-
open_files_var = None
716-
train_open_files_op = None
715+
# open_files_var = None
716+
# train_open_files_op = None
717+
startup_reader_op_list = []
718+
717719
for op in startup_program.global_block().ops:
718-
if op.type == "open_files":
719-
train_open_files_op = op
720-
open_files_var_name = op.output("Out")[0]
721-
open_files_var = startup_program.global_block().vars[
722-
open_files_var_name]
723-
724-
# 2. add operator to startup to read open and read test data files
725-
test_startup_var = startup_program.global_block().create_var(
726-
name=open_files_var.name + "_test")
727-
728-
print("creating openfiles for test reader: ", train_open_files_op.attrs)
729-
startup_program.global_block().append_op(
730-
type='open_files',
731-
outputs={'Out': [test_startup_var]},
732-
attrs={
733-
'shape_concat': train_open_files_op.attrs["shape_concat"],
734-
'lod_levels': train_open_files_op.attrs["lod_levels"],
735-
'ranks': train_open_files_op.attrs["ranks"],
736-
'file_names': filelist,
737-
'thread_num': train_open_files_op.attrs["thread_num"],
738-
'buffer_size': train_open_files_op.attrs["buffer_size"]
739-
})
740-
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in ["float32", "int64"]]
741-
test_startup_var.desc.set_dtypes(dtypes)
742-
test_startup_var.persistable = True
743-
_copy_reader_var_(default_main_program().global_block(), test_startup_var)
720+
if _is_reader_op(op, startup_program.global_block()):
721+
startup_reader_op_list.append(op)
722+
723+
if len(startup_reader_op_list) == 0:
724+
return program
725+
726+
root_reader_op = startup_reader_op_list[0]
727+
728+
# 2. add operators to startup to open and read test data files
729+
for op in startup_reader_op_list:
730+
orig_var_name = op.output("Out")[0]
731+
orig_var = startup_program.global_block().vars[orig_var_name]
732+
new_test_var = _copy_reader_var_(
733+
startup_program.global_block(),
734+
orig_var,
735+
newname=orig_var_name + "_test")
736+
737+
# open_files-like operators have no input.
738+
inputs = None
739+
if "UnderlyingReader" in op.input_names:
740+
orig_input_var_name = op.input("UnderlyingReader")[0]
741+
orig_input_var = startup_program.global_block().vars[
742+
orig_input_var_name]
743+
new_input_var = _copy_reader_var_(
744+
startup_program.global_block(),
745+
orig_input_var,
746+
newname=orig_input_var_name + "_test")
747+
inputs = {"UnderlyingReader": new_input_var}
748+
test_op = startup_program.global_block().append_op(
749+
type=op.type,
750+
inputs=inputs,
751+
outputs={'Out': [new_test_var]},
752+
attrs=op.attrs)
753+
# set the root reader op's file_names attr so that it reads the test files
754+
if op.type == root_reader_op.type:
755+
test_op.set_attr("file_names", filelist)
756+
if op.type == "create_multi_pass_reader":
757+
test_op.set_attr("pass_num", 1)
744758

745759
# 3. rename reader vars in inference program to different name
746760
# to avoid read from train data.
747-
program.global_block().rename_var(open_files_var.name,
748-
test_startup_var.name)
761+
origname = root_reader_op.output("Out")[0]
762+
newname = origname + "_test"
763+
program.global_block().rename_var(str(origname), str(newname))
749764
for op in program.global_block().ops:
750-
if "Out" in op.output_names:
751-
op_out_var_name = op.output("Out")[0]
752-
op_out_var = program.global_block().vars[op_out_var_name]
753-
if op_out_var.type == core.VarDesc.VarType.READER:
754-
newname = op_out_var.name + "_test"
755-
program.global_block().rename_var(op_out_var.name, newname)
765+
if _is_reader_op(op, program.global_block()):
766+
origname = op.output("Out")[0]
767+
newname = origname + "_test"
768+
program.global_block().rename_var(str(origname), str(newname))
769+
756770
if op.type == "create_multi_pass_reader":
757771
op.set_attr("pass_num", 1)
758772

0 commit comments

Comments
 (0)