|
17 | 17 | import shutil
|
18 | 18 |
|
19 | 19 | from paddle.fluid.evaluator import Evaluator
|
20 |
| -from paddle.fluid.framework import Program, Parameter, default_main_program, Variable |
| 20 | +from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable |
21 | 21 | from . import core
|
22 | 22 |
|
23 | 23 | __all__ = [
|
@@ -744,3 +744,101 @@ def has_success(checkpoint_dir, cur_dir):
|
744 | 744 | if success_num > current_dir:
|
745 | 745 | current_dir = success_num
|
746 | 746 | return current_dir
|
| 747 | + |
| 748 | + |
| 749 | +def get_test_program(filelist, program=None, startup_program=None): |
| 750 | + """ |
| 751 | + Transpile current train program to a program to read test dataset |
| 752 | + if the program is using reader ops like "open_files_op". |
| 753 | + """ |
| 754 | + |
| 755 | + def _copy_reader_var_(block, var, new_name=None): |
| 756 | + if new_name == None: |
| 757 | + new_name = var.name |
| 758 | + new_var = block.create_var( |
| 759 | + name=str(new_name), type=core.VarDesc.VarType.READER) |
| 760 | + new_var.desc.set_shapes(var.desc.shapes()) |
| 761 | + new_var.desc.set_dtypes(var.desc.dtypes()) |
| 762 | + new_var.persistable = True |
| 763 | + return new_var |
| 764 | + |
| 765 | + def get_test_reader_name(train_reader_name): |
| 766 | + return train_reader_name + "_test" |
| 767 | + |
| 768 | + def is_reader_op(op): |
| 769 | + block = op.block |
| 770 | + if "Out" in op.output_names: |
| 771 | + reader_out = block.vars[op.output("Out")[0]] |
| 772 | + if reader_out.type == core.VarDesc.VarType.READER: |
| 773 | + return True |
| 774 | + return False |
| 775 | + |
| 776 | + if program == None: |
| 777 | + program = default_main_program() |
| 778 | + if startup_program == None: |
| 779 | + startup_program = default_startup_program() |
| 780 | + startup_block = startup_program.global_block() |
| 781 | + |
| 782 | + # 1. find out the orignal reader var name |
| 783 | + startup_reader_op_list = [] |
| 784 | + |
| 785 | + for op in startup_block.ops: |
| 786 | + if is_reader_op(op): |
| 787 | + startup_reader_op_list.append(op) |
| 788 | + |
| 789 | + if len(startup_reader_op_list) == 0: |
| 790 | + return program |
| 791 | + |
| 792 | + root_reader_op = startup_reader_op_list[0] |
| 793 | + train_test_reader_map = {} |
| 794 | + # 2. add operators to startup to read open and read test data files |
| 795 | + for op in startup_reader_op_list: |
| 796 | + assert (len(op.output("Out")) == 1) |
| 797 | + train_reader_name = op.output("Out")[0] |
| 798 | + train_reader = startup_block.vars[train_reader_name] |
| 799 | + test_reader = _copy_reader_var_( |
| 800 | + startup_block, |
| 801 | + train_reader, |
| 802 | + new_name=get_test_reader_name(train_reader_name)) |
| 803 | + train_test_reader_map[train_reader.name] = test_reader |
| 804 | + |
| 805 | + test_op_inputs = {} |
| 806 | + for name in op.input_names: |
| 807 | + train_arg_names = op.input(name) |
| 808 | + test_arg_vars = [] |
| 809 | + for arg_name in train_arg_names: |
| 810 | + arg_var = train_test_reader_map[ |
| 811 | + arg_name] if name == "UnderlyingReader" else startup_block.vars[ |
| 812 | + arg_name] |
| 813 | + test_arg_vars.append(arg_var) |
| 814 | + test_op_inputs[name] = test_arg_vars |
| 815 | + |
| 816 | + test_op = startup_block.append_op( |
| 817 | + type=op.type, |
| 818 | + inputs=test_op_inputs, |
| 819 | + outputs={'Out': [test_reader]}, |
| 820 | + attrs=op.attrs) |
| 821 | + # root reader op's filelist attr for read test files |
| 822 | + if op.type == root_reader_op.type: |
| 823 | + test_op.set_attr("file_names", filelist) |
| 824 | + if op.type == "create_multi_pass_reader": |
| 825 | + test_op.set_attr("pass_num", 1) |
| 826 | + |
| 827 | + # 3. rename reader vars in inference program to different name |
| 828 | + # to avoid read from train data. |
| 829 | + main_block = program.global_block() |
| 830 | + for var in main_block.vars.values(): |
| 831 | + if var.type == core.VarDesc.VarType.READER: |
| 832 | + main_block.rename_var( |
| 833 | + str(var.name), str(get_test_reader_name(var.name))) |
| 834 | + |
| 835 | + for op in main_block.ops: |
| 836 | + if op.type == root_reader_op.type: |
| 837 | + test_op.set_attr("file_names", filelist) |
| 838 | + if op.type == "create_multi_pass_reader": |
| 839 | + test_op.set_attr("pass_num", 1) |
| 840 | + |
| 841 | + startup_program.sync_with_cpp() |
| 842 | + program.sync_with_cpp() |
| 843 | + |
| 844 | + return program |
0 commit comments