Skip to content

Commit c15c6c1

Browse files
committed
move get_test_program to non-layer io.py
1 parent 7d3d722 commit c15c6c1

File tree

2 files changed

+99
-89
lines changed

2 files changed

+99
-89
lines changed

python/paddle/fluid/io.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import shutil
1818

1919
from paddle.fluid.evaluator import Evaluator
20-
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
20+
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable
2121
from . import core
2222

2323
__all__ = [
@@ -744,3 +744,101 @@ def has_success(checkpoint_dir, cur_dir):
744744
if success_num > current_dir:
745745
current_dir = success_num
746746
return current_dir
747+
748+
749+
def get_test_program(filelist, program=None, startup_program=None):
750+
"""
751+
Transpile current train program to a program to read test dataset
752+
if the program is using reader ops like "open_files_op".
753+
"""
754+
755+
def _copy_reader_var_(block, var, new_name=None):
756+
if new_name == None:
757+
new_name = var.name
758+
new_var = block.create_var(
759+
name=str(new_name), type=core.VarDesc.VarType.READER)
760+
new_var.desc.set_shapes(var.desc.shapes())
761+
new_var.desc.set_dtypes(var.desc.dtypes())
762+
new_var.persistable = True
763+
return new_var
764+
765+
def get_test_reader_name(train_reader_name):
766+
return train_reader_name + "_test"
767+
768+
def is_reader_op(op):
769+
block = op.block
770+
if "Out" in op.output_names:
771+
reader_out = block.vars[op.output("Out")[0]]
772+
if reader_out.type == core.VarDesc.VarType.READER:
773+
return True
774+
return False
775+
776+
if program == None:
777+
program = default_main_program()
778+
if startup_program == None:
779+
startup_program = default_startup_program()
780+
startup_block = startup_program.global_block()
781+
782+
# 1. find out the orignal reader var name
783+
startup_reader_op_list = []
784+
785+
for op in startup_block.ops:
786+
if is_reader_op(op):
787+
startup_reader_op_list.append(op)
788+
789+
if len(startup_reader_op_list) == 0:
790+
return program
791+
792+
root_reader_op = startup_reader_op_list[0]
793+
train_test_reader_map = {}
794+
# 2. add operators to startup to read open and read test data files
795+
for op in startup_reader_op_list:
796+
assert (len(op.output("Out")) == 1)
797+
train_reader_name = op.output("Out")[0]
798+
train_reader = startup_block.vars[train_reader_name]
799+
test_reader = _copy_reader_var_(
800+
startup_block,
801+
train_reader,
802+
new_name=get_test_reader_name(train_reader_name))
803+
train_test_reader_map[train_reader.name] = test_reader
804+
805+
test_op_inputs = {}
806+
for name in op.input_names:
807+
train_arg_names = op.input(name)
808+
test_arg_vars = []
809+
for arg_name in train_arg_names:
810+
arg_var = train_test_reader_map[
811+
arg_name] if name == "UnderlyingReader" else startup_block.vars[
812+
arg_name]
813+
test_arg_vars.append(arg_var)
814+
test_op_inputs[name] = test_arg_vars
815+
816+
test_op = startup_block.append_op(
817+
type=op.type,
818+
inputs=test_op_inputs,
819+
outputs={'Out': [test_reader]},
820+
attrs=op.attrs)
821+
# root reader op's filelist attr for read test files
822+
if op.type == root_reader_op.type:
823+
test_op.set_attr("file_names", filelist)
824+
if op.type == "create_multi_pass_reader":
825+
test_op.set_attr("pass_num", 1)
826+
827+
# 3. rename reader vars in inference program to different name
828+
# to avoid read from train data.
829+
main_block = program.global_block()
830+
for var in main_block.vars.values():
831+
if var.type == core.VarDesc.VarType.READER:
832+
main_block.rename_var(
833+
str(var.name), str(get_test_reader_name(var.name)))
834+
835+
for op in main_block.ops:
836+
if op.type == root_reader_op.type:
837+
test_op.set_attr("file_names", filelist)
838+
if op.type == "create_multi_pass_reader":
839+
test_op.set_attr("pass_num", 1)
840+
841+
startup_program.sync_with_cpp()
842+
program.sync_with_cpp()
843+
844+
return program

python/paddle/fluid/layers/io.py

Lines changed: 0 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -692,91 +692,3 @@ def load(out, file_path, load_as_fp16=None):
692692
if load_as_fp16 is not None:
693693
attrs['load_as_fp16'] = load_as_fp16
694694
helper.append_op(type="load", inputs={}, output={"Out": out}, args=attrs)
695-
696-
697-
def get_test_program(filelist, program=None, startup_program=None):
698-
"""
699-
Transpile current train program to a program to read test dataset
700-
if the program is using reader ops like "open_files_op".
701-
"""
702-
703-
def get_test_reader_name(train_reader_name):
704-
return train_reader_name + "_test"
705-
706-
def is_reader_op(op):
707-
block = op.block
708-
if "Out" in op.output_names:
709-
reader_out = block.vars[op.output("Out")[0]]
710-
if reader_out.type == core.VarDesc.VarType.READER:
711-
return True
712-
return False
713-
714-
if program == None:
715-
program = default_main_program()
716-
if startup_program == None:
717-
startup_program = default_startup_program()
718-
startup_block = startup_program.global_block()
719-
720-
# 1. find out the orignal reader var name
721-
startup_reader_op_list = []
722-
723-
for op in startup_block.ops:
724-
if is_reader_op(op):
725-
startup_reader_op_list.append(op)
726-
727-
if len(startup_reader_op_list) == 0:
728-
return program
729-
730-
root_reader_op = startup_reader_op_list[0]
731-
train_test_reader_map = {}
732-
# 2. add operators to startup to read open and read test data files
733-
for op in startup_reader_op_list:
734-
assert (len(op.output("Out")) == 1)
735-
train_reader_name = op.output("Out")[0]
736-
train_reader = startup_block.vars[train_reader_name]
737-
test_reader = _copy_reader_var_(
738-
startup_block,
739-
train_reader,
740-
new_name=get_test_reader_name(train_reader_name))
741-
train_test_reader_map[train_reader.name] = test_reader
742-
743-
test_op_inputs = {}
744-
for name in op.input_names:
745-
train_arg_names = op.input(name)
746-
test_arg_vars = []
747-
for arg_name in train_arg_names:
748-
arg_var = train_test_reader_map[
749-
arg_name] if name == "UnderlyingReader" else startup_block.vars[
750-
arg_name]
751-
test_arg_vars.append(arg_var)
752-
test_op_inputs[name] = test_arg_vars
753-
754-
test_op = startup_block.append_op(
755-
type=op.type,
756-
inputs=test_op_inputs,
757-
outputs={'Out': [test_reader]},
758-
attrs=op.attrs)
759-
# root reader op's filelist attr for read test files
760-
if op.type == root_reader_op.type:
761-
test_op.set_attr("file_names", filelist)
762-
if op.type == "create_multi_pass_reader":
763-
test_op.set_attr("pass_num", 1)
764-
765-
# 3. rename reader vars in inference program to different name
766-
# to avoid read from train data.
767-
main_block = program.global_block()
768-
for var in main_block.vars.values():
769-
if var.type == core.VarDesc.VarType.READER:
770-
main_block.rename_var(
771-
str(var.name), str(get_test_reader_name(var.name)))
772-
773-
for op in main_block.ops:
774-
if op.type == root_reader_op.type:
775-
test_op.set_attr("file_names", filelist)
776-
if op.type == "create_multi_pass_reader":
777-
test_op.set_attr("pass_num", 1)
778-
779-
startup_program.sync_with_cpp()
780-
program.sync_with_cpp()
781-
782-
return program

0 commit comments

Comments
 (0)