Skip to content

Commit 7d3d722

Browse files
committed
refine get_test_program
1 parent a424f5a commit 7d3d722

File tree

1 file changed

+54
-47
lines changed
  • python/paddle/fluid/layers

1 file changed

+54
-47
lines changed

python/paddle/fluid/layers/io.py

Lines changed: 54 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,11 @@ def reset():
259259
return reader
260260

261261

262-
def _copy_reader_var_(block, var, newname=None):
263-
if newname == None:
264-
newname = var.name
265-
new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
262+
def _copy_reader_var_(block, var, new_name=None):
263+
if new_name == None:
264+
new_name = var.name
265+
new_var = block.create_var(
266+
name=str(new_name), type=core.VarDesc.VarType.READER)
266267
new_var.desc.set_shapes(var.desc.shapes())
267268
new_var.desc.set_dtypes(var.desc.dtypes())
268269
new_var.persistable = True
@@ -693,62 +694,67 @@ def load(out, file_path, load_as_fp16=None):
693694
helper.append_op(type="load", inputs={}, output={"Out": out}, args=attrs)
694695

695696

696-
def _is_reader_op(op, block):
697-
if "Out" in op.output_names:
698-
reader_out = block.vars[op.output("Out")[0]]
699-
if reader_out.type == core.VarDesc.VarType.READER:
700-
return True
701-
return False
702-
703-
704697
def get_test_program(filelist, program=None, startup_program=None):
705698
"""
706699
Transpile current train program to a program to read test dataset
707700
if the program is using reader ops like "open_files_op".
708701
"""
702+
703+
def get_test_reader_name(train_reader_name):
704+
return train_reader_name + "_test"
705+
706+
def is_reader_op(op):
707+
block = op.block
708+
if "Out" in op.output_names:
709+
reader_out = block.vars[op.output("Out")[0]]
710+
if reader_out.type == core.VarDesc.VarType.READER:
711+
return True
712+
return False
713+
709714
if program == None:
710715
program = default_main_program()
711716
if startup_program == None:
712717
startup_program = default_startup_program()
718+
startup_block = startup_program.global_block()
713719

714720
# 1. find out the orignal reader var name
715-
# open_files_var = None
716-
# train_open_files_op = None
717721
startup_reader_op_list = []
718722

719-
for op in startup_program.global_block().ops:
720-
if _is_reader_op(op, startup_program.global_block()):
723+
for op in startup_block.ops:
724+
if is_reader_op(op):
721725
startup_reader_op_list.append(op)
722726

723727
if len(startup_reader_op_list) == 0:
724728
return program
725729

726730
root_reader_op = startup_reader_op_list[0]
727-
731+
train_test_reader_map = {}
728732
# 2. add operators to startup to read open and read test data files
729733
for op in startup_reader_op_list:
730-
orig_var_name = op.output("Out")[0]
731-
orig_var = startup_program.global_block().vars[orig_var_name]
732-
new_test_var = _copy_reader_var_(
733-
startup_program.global_block(),
734-
orig_var,
735-
newname=orig_var_name + "_test")
736-
737-
# for open_files like operators have no input.
738-
inputs = None
739-
if "UnderlyingReader" in op.input_names:
740-
orig_input_var_name = op.input("UnderlyingReader")[0]
741-
orig_input_var = startup_program.global_block().vars[
742-
orig_input_var_name]
743-
new_input_var = _copy_reader_var_(
744-
startup_program.global_block(),
745-
orig_input_var,
746-
newname=orig_input_var_name + "_test")
747-
inputs = {"UnderlyingReader": new_input_var}
748-
test_op = startup_program.global_block().append_op(
734+
assert (len(op.output("Out")) == 1)
735+
train_reader_name = op.output("Out")[0]
736+
train_reader = startup_block.vars[train_reader_name]
737+
test_reader = _copy_reader_var_(
738+
startup_block,
739+
train_reader,
740+
new_name=get_test_reader_name(train_reader_name))
741+
train_test_reader_map[train_reader.name] = test_reader
742+
743+
test_op_inputs = {}
744+
for name in op.input_names:
745+
train_arg_names = op.input(name)
746+
test_arg_vars = []
747+
for arg_name in train_arg_names:
748+
arg_var = train_test_reader_map[
749+
arg_name] if name == "UnderlyingReader" else startup_block.vars[
750+
arg_name]
751+
test_arg_vars.append(arg_var)
752+
test_op_inputs[name] = test_arg_vars
753+
754+
test_op = startup_block.append_op(
749755
type=op.type,
750-
inputs=inputs,
751-
outputs={'Out': [new_test_var]},
756+
inputs=test_op_inputs,
757+
outputs={'Out': [test_reader]},
752758
attrs=op.attrs)
753759
# root reader op's filelist attr for read test files
754760
if op.type == root_reader_op.type:
@@ -758,18 +764,19 @@ def get_test_program(filelist, program=None, startup_program=None):
758764

759765
# 3. rename reader vars in inference program to different name
760766
# to avoid read from train data.
761-
origname = root_reader_op.output("Out")[0]
762-
newname = origname + "_test"
763-
program.global_block().rename_var(str(origname), str(newname))
764-
for op in program.global_block().ops:
765-
if _is_reader_op(op, program.global_block()):
766-
origname = op.output("Out")[0]
767-
newname = origname + "_test"
768-
program.global_block().rename_var(str(origname), str(newname))
767+
main_block = program.global_block()
768+
for var in main_block.vars.values():
769+
if var.type == core.VarDesc.VarType.READER:
770+
main_block.rename_var(
771+
str(var.name), str(get_test_reader_name(var.name)))
769772

773+
for op in main_block.ops:
774+
if op.type == root_reader_op.type:
775+
test_op.set_attr("file_names", filelist)
770776
if op.type == "create_multi_pass_reader":
771-
op.set_attr("pass_num", 1)
777+
test_op.set_attr("pass_num", 1)
772778

779+
startup_program.sync_with_cpp()
773780
program.sync_with_cpp()
774781

775782
return program

0 commit comments

Comments
 (0)