31
31
add_arg ('model_path' , str , "./pruning/checkpoints/resnet50/2/eval_model/" , "Whether to use pretrained model." )
32
32
add_arg ('model_name' , str , None , "model filename for inference model" )
33
33
add_arg ('params_name' , str , None , "params filename for inference model" )
34
+ add_arg ('batch_size' , int , 64 , "Minibatch size." )
34
35
# yapf: enable
35
36
36
37
37
38
def eval (args ):
38
- # parameters from arguments
39
-
40
39
place = paddle .CUDAPlace (0 ) if args .use_gpu else paddle .CPUPlace ()
41
40
exe = paddle .static .Executor (place )
42
41
@@ -45,23 +44,29 @@ def eval(args):
45
44
exe ,
46
45
model_filename = args .model_name ,
47
46
params_filename = args .params_name )
48
- val_reader = paddle . batch ( reader .val (), batch_size = 1 )
47
+ val_dataset = reader .ImageNetDataset ( mode = 'val' )
49
48
50
49
image = paddle .static .data (
51
50
name = 'image' , shape = [None , 3 , 224 , 224 ], dtype = 'float32' )
52
51
label = paddle .static .data (name = 'label' , shape = [None , 1 ], dtype = 'int64' )
53
52
54
- valid_loader = paddle .io .DataLoader .from_generator (
55
- feed_list = [image ], capacity = 512 , use_double_buffer = True , iterable = True )
56
- valid_loader .set_sample_list_generator (val_reader , place )
53
+ val_loader = paddle .io .DataLoader (
54
+ val_dataset ,
55
+ places = place ,
56
+ feed_list = [image , label ],
57
+ drop_last = False ,
58
+ return_list = True ,
59
+ batch_size = args .batch_size ,
60
+ use_shared_memory = True ,
61
+ shuffle = False )
57
62
58
63
results = []
59
- for batch_id , data in enumerate (val_reader ()):
64
+ for batch_id , data in enumerate (val_loader ()):
60
65
# top1_acc, top5_acc
61
66
if len (feed_target_names ) == 1 :
62
67
# eval "infer model", which input is image, output is classification probability
63
- image = data [0 ][ 0 ]. reshape (( 1 , 3 , 224 , 224 ))
64
- label = [[ d [ 1 ]] for d in data ]
68
+ image = data [0 ]
69
+ label = data [ 1 ]
65
70
pred = exe .run (val_program ,
66
71
feed = {feed_target_names [0 ]: image },
67
72
fetch_list = fetch_targets )
@@ -79,8 +84,8 @@ def eval(args):
79
84
results .append ([top_1 , top_5 ])
80
85
else :
81
86
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
82
- image = data [0 ][ 0 ]. reshape (( 1 , 3 , 224 , 224 ))
83
- label = [[ d [ 1 ]] for d in data ]
87
+ image = data [0 ]
88
+ label = data [ 1 ]
84
89
result = exe .run (val_program ,
85
90
feed = {
86
91
feed_target_names [0 ]: image ,
@@ -89,6 +94,8 @@ def eval(args):
89
94
fetch_list = fetch_targets )
90
95
result = [np .mean (r ) for r in result ]
91
96
results .append (result )
97
+ if batch_id % 100 == 0 :
98
+ print ('Eval iter: ' , batch_id )
92
99
result = np .mean (np .array (results ), axis = 0 )
93
100
print ("top1_acc/top5_acc= {}" .format (result ))
94
101
sys .stdout .flush ()
0 commit comments