@@ -259,10 +259,11 @@ def reset():
259
259
return reader
260
260
261
261
262
- def _copy_reader_var_ (block , var , newname = None ):
263
- if newname == None :
264
- newname = var .name
265
- new_var = block .create_var (name = var .name , type = core .VarDesc .VarType .READER )
262
+ def _copy_reader_var_ (block , var , new_name = None ):
263
+ if new_name == None :
264
+ new_name = var .name
265
+ new_var = block .create_var (
266
+ name = str (new_name ), type = core .VarDesc .VarType .READER )
266
267
new_var .desc .set_shapes (var .desc .shapes ())
267
268
new_var .desc .set_dtypes (var .desc .dtypes ())
268
269
new_var .persistable = True
@@ -693,62 +694,67 @@ def load(out, file_path, load_as_fp16=None):
693
694
helper .append_op (type = "load" , inputs = {}, output = {"Out" : out }, args = attrs )
694
695
695
696
696
- def _is_reader_op (op , block ):
697
- if "Out" in op .output_names :
698
- reader_out = block .vars [op .output ("Out" )[0 ]]
699
- if reader_out .type == core .VarDesc .VarType .READER :
700
- return True
701
- return False
702
-
703
-
704
697
def get_test_program (filelist , program = None , startup_program = None ):
705
698
"""
706
699
Transpile current train program to a program to read test dataset
707
700
if the program is using reader ops like "open_files_op".
708
701
"""
702
+
703
+ def get_test_reader_name (train_reader_name ):
704
+ return train_reader_name + "_test"
705
+
706
+ def is_reader_op (op ):
707
+ block = op .block
708
+ if "Out" in op .output_names :
709
+ reader_out = block .vars [op .output ("Out" )[0 ]]
710
+ if reader_out .type == core .VarDesc .VarType .READER :
711
+ return True
712
+ return False
713
+
709
714
if program == None :
710
715
program = default_main_program ()
711
716
if startup_program == None :
712
717
startup_program = default_startup_program ()
718
+ startup_block = startup_program .global_block ()
713
719
714
720
# 1. find out the orignal reader var name
715
- # open_files_var = None
716
- # train_open_files_op = None
717
721
startup_reader_op_list = []
718
722
719
- for op in startup_program . global_block () .ops :
720
- if _is_reader_op (op , startup_program . global_block () ):
723
+ for op in startup_block .ops :
724
+ if is_reader_op (op ):
721
725
startup_reader_op_list .append (op )
722
726
723
727
if len (startup_reader_op_list ) == 0 :
724
728
return program
725
729
726
730
root_reader_op = startup_reader_op_list [0 ]
727
-
731
+ train_test_reader_map = {}
728
732
# 2. add operators to startup to read open and read test data files
729
733
for op in startup_reader_op_list :
730
- orig_var_name = op .output ("Out" )[0 ]
731
- orig_var = startup_program .global_block ().vars [orig_var_name ]
732
- new_test_var = _copy_reader_var_ (
733
- startup_program .global_block (),
734
- orig_var ,
735
- newname = orig_var_name + "_test" )
736
-
737
- # for open_files like operators have no input.
738
- inputs = None
739
- if "UnderlyingReader" in op .input_names :
740
- orig_input_var_name = op .input ("UnderlyingReader" )[0 ]
741
- orig_input_var = startup_program .global_block ().vars [
742
- orig_input_var_name ]
743
- new_input_var = _copy_reader_var_ (
744
- startup_program .global_block (),
745
- orig_input_var ,
746
- newname = orig_input_var_name + "_test" )
747
- inputs = {"UnderlyingReader" : new_input_var }
748
- test_op = startup_program .global_block ().append_op (
734
+ assert (len (op .output ("Out" )) == 1 )
735
+ train_reader_name = op .output ("Out" )[0 ]
736
+ train_reader = startup_block .vars [train_reader_name ]
737
+ test_reader = _copy_reader_var_ (
738
+ startup_block ,
739
+ train_reader ,
740
+ new_name = get_test_reader_name (train_reader_name ))
741
+ train_test_reader_map [train_reader .name ] = test_reader
742
+
743
+ test_op_inputs = {}
744
+ for name in op .input_names :
745
+ train_arg_names = op .input (name )
746
+ test_arg_vars = []
747
+ for arg_name in train_arg_names :
748
+ arg_var = train_test_reader_map [
749
+ arg_name ] if name == "UnderlyingReader" else startup_block .vars [
750
+ arg_name ]
751
+ test_arg_vars .append (arg_var )
752
+ test_op_inputs [name ] = test_arg_vars
753
+
754
+ test_op = startup_block .append_op (
749
755
type = op .type ,
750
- inputs = inputs ,
751
- outputs = {'Out' : [new_test_var ]},
756
+ inputs = test_op_inputs ,
757
+ outputs = {'Out' : [test_reader ]},
752
758
attrs = op .attrs )
753
759
# root reader op's filelist attr for read test files
754
760
if op .type == root_reader_op .type :
@@ -758,18 +764,19 @@ def get_test_program(filelist, program=None, startup_program=None):
758
764
759
765
# 3. rename reader vars in inference program to different name
760
766
# to avoid read from train data.
761
- origname = root_reader_op .output ("Out" )[0 ]
762
- newname = origname + "_test"
763
- program .global_block ().rename_var (str (origname ), str (newname ))
764
- for op in program .global_block ().ops :
765
- if _is_reader_op (op , program .global_block ()):
766
- origname = op .output ("Out" )[0 ]
767
- newname = origname + "_test"
768
- program .global_block ().rename_var (str (origname ), str (newname ))
767
+ main_block = program .global_block ()
768
+ for var in main_block .vars .values ():
769
+ if var .type == core .VarDesc .VarType .READER :
770
+ main_block .rename_var (
771
+ str (var .name ), str (get_test_reader_name (var .name )))
769
772
773
+ for op in main_block .ops :
774
+ if op .type == root_reader_op .type :
775
+ test_op .set_attr ("file_names" , filelist )
770
776
if op .type == "create_multi_pass_reader" :
771
- op .set_attr ("pass_num" , 1 )
777
+ test_op .set_attr ("pass_num" , 1 )
772
778
779
+ startup_program .sync_with_cpp ()
773
780
program .sync_with_cpp ()
774
781
775
782
return program
0 commit comments