17
17
__all__ = ['parse_args' , ]
18
18
19
19
BENCHMARK_MODELS = [
20
- "machine_translation" , "resnet" , "vgg" , "mnist" , "stacked_dynamic_lstm"
20
+ "machine_translation" , "resnet" , "se_resnext" , "vgg" , "mnist" ,
21
+ "stacked_dynamic_lstm" , "resnet_with_preprocess"
21
22
]
22
23
23
24
@@ -67,12 +68,12 @@ def parse_args():
67
68
'--cpus' ,
68
69
type = int ,
69
70
default = 1 ,
70
- help = 'If cpus > 1, will use ParallelDo to run, else use Executor .' )
71
+ help = 'If cpus > 1, will set ParallelExecutor to use multiple threads .' )
71
72
parser .add_argument (
72
73
'--data_set' ,
73
74
type = str ,
74
75
default = 'flowers' ,
75
- choices = ['cifar10' , 'flowers' ],
76
+ choices = ['cifar10' , 'flowers' , 'imagenet' ],
76
77
help = 'Optional dataset for benchmark.' )
77
78
parser .add_argument (
78
79
'--infer_only' , action = 'store_true' , help = 'If set, run forward only.' )
@@ -122,6 +123,11 @@ def parse_args():
122
123
type = str ,
123
124
default = "" ,
124
125
help = 'Directory that contains all the training recordio files.' )
126
+ parser .add_argument (
127
+ '--test_data_path' ,
128
+ type = str ,
129
+ default = "" ,
130
+ help = 'Directory that contains all the test data (NOT recordio).' )
125
131
parser .add_argument (
126
132
'--use_inference_transpiler' ,
127
133
action = 'store_true' ,
@@ -130,5 +136,9 @@ def parse_args():
130
136
'--no_random' ,
131
137
action = 'store_true' ,
132
138
help = 'If set, keep the random seed and do not shuffle the data.' )
139
+ parser .add_argument (
140
+ '--use_lars' ,
141
+ action = 'store_true' ,
142
+ help = 'If set, use lars for optimizers, ONLY support resnet module.' )
133
143
args = parser .parse_args ()
134
144
return args
0 commit comments