@@ -259,7 +259,9 @@ def reset():
259
259
return reader
260
260
261
261
262
- def _copy_reader_var_ (block , var ):
262
+ def _copy_reader_var_ (block , var , newname = None ):
263
+ if newname == None :
264
+ newname = var .name
263
265
new_var = block .create_var (name = var .name , type = core .VarDesc .VarType .READER )
264
266
new_var .desc .set_shapes (var .desc .shapes ())
265
267
new_var .desc .set_dtypes (var .desc .dtypes ())
@@ -691,68 +693,80 @@ def load(out, file_path, load_as_fp16=None):
691
693
helper .append_op (type = "load" , inputs = {}, output = {"Out" : out }, args = attrs )
692
694
693
695
694
- def get_test_program (filelist , test_program = None , startup_program = None ):
695
- """
696
- Transpile current program to read test dataset if the program
697
- is using reader ops like "open_files_op".
696
+ def _is_reader_op (op , block ):
697
+ if "Out" in op .output_names :
698
+ reader_out = block .vars [op .output ("Out" )[0 ]]
699
+ if reader_out .type == core .VarDesc .VarType .READER :
700
+ return True
701
+ return False
698
702
699
- Args:
700
- filelist (list): list of test file paths.
701
- test_program (Program|None): program to run test/evaluation.
702
- default use fluid.default_main_program()
703
- startup_program (Program|None): startup program to change,
704
- default use fluid.default_startup_program()
705
-
706
- Returns:
707
- Program: program for test
703
+
704
+ def get_test_program (filelist , program = None , startup_program = None ):
705
+ """
706
+ Transpile current train program to a program to read test dataset
707
+ if the program is using reader ops like "open_files_op".
708
708
"""
709
- if test_program == None :
709
+ if program == None :
710
710
program = default_main_program ()
711
711
if startup_program == None :
712
712
startup_program = default_startup_program ()
713
713
714
714
# 1. find out the orignal reader var name
715
- open_files_var = None
716
- train_open_files_op = None
715
+ # open_files_var = None
716
+ # train_open_files_op = None
717
+ startup_reader_op_list = []
718
+
717
719
for op in startup_program .global_block ().ops :
718
- if op .type == "open_files" :
719
- train_open_files_op = op
720
- open_files_var_name = op .output ("Out" )[0 ]
721
- open_files_var = startup_program .global_block ().vars [
722
- open_files_var_name ]
723
-
724
- # 2. add operator to startup to read open and read test data files
725
- test_startup_var = startup_program .global_block ().create_var (
726
- name = open_files_var .name + "_test" )
727
-
728
- print ("creating openfiles for test reader: " , train_open_files_op .attrs )
729
- startup_program .global_block ().append_op (
730
- type = 'open_files' ,
731
- outputs = {'Out' : [test_startup_var ]},
732
- attrs = {
733
- 'shape_concat' : train_open_files_op .attrs ["shape_concat" ],
734
- 'lod_levels' : train_open_files_op .attrs ["lod_levels" ],
735
- 'ranks' : train_open_files_op .attrs ["ranks" ],
736
- 'file_names' : filelist ,
737
- 'thread_num' : train_open_files_op .attrs ["thread_num" ],
738
- 'buffer_size' : train_open_files_op .attrs ["buffer_size" ]
739
- })
740
- dtypes = [convert_np_dtype_to_dtype_ (dt ) for dt in ["float32" , "int64" ]]
741
- test_startup_var .desc .set_dtypes (dtypes )
742
- test_startup_var .persistable = True
743
- _copy_reader_var_ (default_main_program ().global_block (), test_startup_var )
720
+ if _is_reader_op (op , startup_program .global_block ()):
721
+ startup_reader_op_list .append (op )
722
+
723
+ if len (startup_reader_op_list ) == 0 :
724
+ return program
725
+
726
+ root_reader_op = startup_reader_op_list [0 ]
727
+
728
+ # 2. add operators to startup to read open and read test data files
729
+ for op in startup_reader_op_list :
730
+ orig_var_name = op .output ("Out" )[0 ]
731
+ orig_var = startup_program .global_block ().vars [orig_var_name ]
732
+ new_test_var = _copy_reader_var_ (
733
+ startup_program .global_block (),
734
+ orig_var ,
735
+ newname = orig_var_name + "_test" )
736
+
737
+ # for open_files like operators have no input.
738
+ inputs = None
739
+ if "UnderlyingReader" in op .input_names :
740
+ orig_input_var_name = op .input ("UnderlyingReader" )[0 ]
741
+ orig_input_var = startup_program .global_block ().vars [
742
+ orig_input_var_name ]
743
+ new_input_var = _copy_reader_var_ (
744
+ startup_program .global_block (),
745
+ orig_input_var ,
746
+ newname = orig_input_var_name + "_test" )
747
+ inputs = {"UnderlyingReader" : new_input_var }
748
+ test_op = startup_program .global_block ().append_op (
749
+ type = op .type ,
750
+ inputs = inputs ,
751
+ outputs = {'Out' : [new_test_var ]},
752
+ attrs = op .attrs )
753
+ # root reader op's filelist attr for read test files
754
+ if op .type == root_reader_op .type :
755
+ test_op .set_attr ("file_names" , filelist )
756
+ if op .type == "create_multi_pass_reader" :
757
+ test_op .set_attr ("pass_num" , 1 )
744
758
745
759
# 3. rename reader vars in inference program to different name
746
760
# to avoid read from train data.
747
- program .global_block ().rename_var (open_files_var .name ,
748
- test_startup_var .name )
761
+ origname = root_reader_op .output ("Out" )[0 ]
762
+ newname = origname + "_test"
763
+ program .global_block ().rename_var (str (origname ), str (newname ))
749
764
for op in program .global_block ().ops :
750
- if "Out" in op .output_names :
751
- op_out_var_name = op .output ("Out" )[0 ]
752
- op_out_var = program .global_block ().vars [op_out_var_name ]
753
- if op_out_var .type == core .VarDesc .VarType .READER :
754
- newname = op_out_var .name + "_test"
755
- program .global_block ().rename_var (op_out_var .name , newname )
765
+ if _is_reader_op (op , program .global_block ()):
766
+ origname = op .output ("Out" )[0 ]
767
+ newname = origname + "_test"
768
+ program .global_block ().rename_var (str (origname ), str (newname ))
769
+
756
770
if op .type == "create_multi_pass_reader" :
757
771
op .set_attr ("pass_num" , 1 )
758
772
0 commit comments