 add_arg('devices', str, 'gpu', "which device used to compress.")
 add_arg('batch_size', int, 1, "train batch size.")
 add_arg('config_path', str, None, "path of compression strategy config.")
-# yapf: enable
+add_arg('data_dir', str, None, "path of dataset")


+# yapf: enable
 def reader_wrapper(reader):
     def gen():
         for i, data in enumerate(reader()):
@@ -38,7 +39,8 @@ def gen():


 def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
-    val_reader = paddle.batch(reader.val(), batch_size=1)
+
+    val_reader = paddle.batch(reader.val(data_dir=data_dir), batch_size=1)
     image = paddle.static.data(
         name='x', shape=[None, 3, 224, 224], dtype='float32')
     label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
@@ -47,7 +49,6 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
     for batch_id, data in enumerate(val_reader()):
         # top1_acc, top5_acc
         if len(test_feed_names) == 1:
-            # eval "infer model", which input is image, output is classification probability
             image = data[0][0].reshape((1, 3, 224, 224))
             label = [[d[1]] for d in data]
             pred = exe.run(compiled_test_program,
@@ -76,6 +77,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
                              fetch_list=test_fetch_list)
             result = [np.mean(r) for r in result]
             results.append(result)
+        if batch_id % 5000 == 0:
+            print('Eval iter: ', batch_id)
     result = np.mean(np.array(results), axis=0)
     return result[0]

@@ -85,8 +88,10 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
     print_arguments(args)
     paddle.enable_static()
     compress_config, train_config = load_config(args.config_path)
+    data_dir = args.data_dir

-    train_reader = paddle.batch(reader.train(), batch_size=64)
+    train_reader = paddle.batch(
+        reader.train(data_dir=data_dir), batch_size=args.batch_size)
     train_dataloader = reader_wrapper(train_reader)

     ac = AutoCompression(
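For reference, the context hunk above cuts off inside reader_wrapper. In PaddleSlim's image-classification demos this kind of wrapper typically just adapts a paddle.batch reader into a generator of feed dicts keyed by the static input name ('x' here, matching the paddle.static.data declaration in eval_function). The sketch below assumes that shape; the imgs variable and the exact stacking are illustrative, not the file's verbatim body.

import numpy as np

def reader_wrapper(reader):
    # Sketch under the assumption above: each yielded dict feeds the
    # static-graph input named 'x' with a float32 batch of images.
    def gen():
        for i, data in enumerate(reader()):  # index mirrors the hunk above; it is unused
            imgs = np.float32([item[0] for item in data])  # 'imgs' is an illustrative name
            yield {"x": imgs}

    return gen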