6
6
from multiprocessing import Pool , cpu_count
7
7
from os .path import expanduser
8
8
from time import sleep
9
- from typing import Tuple
9
+ from typing import Tuple , Optional
10
10
11
+ from atcodertools .codegen .code_gen_config import CodeGenConfig
11
12
from atcodertools .codegen .cpp_code_generator import CppCodeGenerator
12
13
from atcodertools .codegen .java_code_generator import JavaCodeGenerator
13
14
from atcodertools .fileutils .create_contest_file import create_examples , create_code_from_prediction_result
@@ -43,7 +44,9 @@ def prepare_procedure(atcoder_client: AtCoderClient,
43
44
workspace_root_path : str ,
44
45
template_code_path : str ,
45
46
replacement_code_path : str ,
46
- lang : str ):
47
+ lang : str ,
48
+ config : CodeGenConfig ,
49
+ ):
47
50
pid = problem .get_alphabet ()
48
51
workspace_dir_path = os .path .join (
49
52
workspace_root_path ,
@@ -112,7 +115,7 @@ def emit_info(text):
112
115
113
116
create_code_from_prediction_result (
114
117
result ,
115
- gen_class (template ),
118
+ gen_class (template , config ),
116
119
code_file_path )
117
120
emit_info (
118
121
"Prediction succeeded -- Saved auto-generated code to '{}'" .format (code_file_path ))
@@ -141,11 +144,11 @@ def emit_info(text):
141
144
emit_info ("Saved metadata to {}" .format (metadata_path ))
142
145
143
146
144
- def func (argv : Tuple [AtCoderClient , Problem , str , str , str , str ]):
145
- atcoder_client , problem , workspace_root_path , template_code_path , replacement_code_path , lang = argv
147
+ def func (argv : Tuple [AtCoderClient , Problem , str , str , str , str , CodeGenConfig ]):
148
+ atcoder_client , problem , workspace_root_path , template_code_path , replacement_code_path , lang , config = argv
146
149
prepare_procedure (
147
150
atcoder_client , problem , workspace_root_path , template_code_path ,
148
- replacement_code_path , lang )
151
+ replacement_code_path , lang , config )
149
152
150
153
151
154
def prepare_workspace (atcoder_client : AtCoderClient ,
@@ -154,7 +157,9 @@ def prepare_workspace(atcoder_client: AtCoderClient,
154
157
template_code_path : str ,
155
158
replacement_code_path : str ,
156
159
lang : str ,
157
- parallel : bool ):
160
+ parallel : bool ,
161
+ config : CodeGenConfig ,
162
+ ):
158
163
retry_duration = 1.5
159
164
while True :
160
165
problem_list = atcoder_client .download_problem_list (
@@ -165,7 +170,7 @@ def prepare_workspace(atcoder_client: AtCoderClient,
165
170
logging .warning (
166
171
"Failed to fetch. Will retry in {} seconds" .format (retry_duration ))
167
172
168
- tasks = [(atcoder_client , problem , workspace_root_path , template_code_path , replacement_code_path , lang ) for
173
+ tasks = [(atcoder_client , problem , workspace_root_path , template_code_path , replacement_code_path , lang , config ) for
169
174
problem in problem_list ]
170
175
if parallel :
171
176
thread_pool = Pool (processes = cpu_count ())
@@ -202,41 +207,61 @@ def check_lang(lang: str):
202
207
return lang
203
208
204
209
210
+ USER_CONFIG_PATH = os .path .join (
211
+ expanduser ("~" ), ".atcodertools.toml" )
212
+ DEFAULT_CONFIG_PATH = os .path .abspath (
213
+ os .path .join (script_dir_path , "./atcodertools-default.toml" ))
214
+
215
+
216
+ def get_code_gen_config (config_path : Optional [str ] = None ):
217
+ def _load (path : str ):
218
+ logging .info ("Going to load {} as config" .format (path ))
219
+ with open (path , 'r' ) as f :
220
+ return CodeGenConfig .load (f )
221
+
222
+ if config_path :
223
+ return _load (config_path )
224
+
225
+ if os .path .exists (USER_CONFIG_PATH ):
226
+ return _load (USER_CONFIG_PATH )
227
+
228
+ return _load (DEFAULT_CONFIG_PATH )
229
+
230
+
205
231
def main (prog , args ):
206
232
parser = argparse .ArgumentParser (
207
233
prog = prog ,
208
234
formatter_class = argparse .RawTextHelpFormatter )
209
235
210
236
parser .add_argument ("contest_id" ,
211
- help = "contest ID (e.g. arc001)" )
237
+ help = "Contest ID (e.g. arc001)" )
212
238
213
239
parser .add_argument ("--without-login" ,
214
240
action = "store_true" ,
215
- help = "download data without login" )
241
+ help = "Download data without login" )
216
242
217
243
parser .add_argument ("--workspace" ,
218
- help = "path to workspace's root directory. This script will create files"
244
+ help = "Path to workspace's root directory. This script will create files"
219
245
" in {{WORKSPACE}}/{{contest_name}}/{{alphabet}}/ e.g. ./your-workspace/arc001/A/\n "
220
246
"[Default] {}" .format (DEFAULT_WORKSPACE_DIR_PATH ),
221
247
default = DEFAULT_WORKSPACE_DIR_PATH )
222
248
223
249
parser .add_argument ("--lang" ,
224
- help = "programming language of your template code, {}.\n "
250
+ help = "Programming language of your template code, {}.\n "
225
251
.format (" or " .join (SUPPORTED_LANGUAGES )) + "[Default] {}" .format (DEFAULT_LANG ),
226
252
default = DEFAULT_LANG ,
227
253
type = check_lang )
228
254
229
255
parser .add_argument ("--template" ,
230
- help = "{0}{1}" . format ( "file path to your template code\n "
231
- "[Default (C++)] {}\n " .format (
232
- get_default_template_path ('cpp' )),
233
- "[Default (Java)] {}" .format (
234
- get_default_template_path ('java' )))
256
+ help = "File path to your template code\n {0}{1}" . format (
257
+ "[Default (C++)] {}\n " .format (
258
+ get_default_template_path ('cpp' )),
259
+ "[Default (Java)] {}" .format (
260
+ get_default_template_path ('java' )))
235
261
)
236
262
237
263
parser .add_argument ("--replacement" ,
238
- help = "{0}{1}" .format (
239
- "file path to the replacement code created when template generation is failed.\n "
264
+ help = "File path to your config file\n {0}{1}" .format (
240
265
"[Default (C++)] {}\n " .format (get_default_replacement_path ('cpp' )),
241
266
"[Default (Java)] {}" .format (
242
267
get_default_replacement_path ('java' )))
@@ -252,6 +277,13 @@ def main(prog, args):
252
277
help = "Save no session cache to avoid security risk" ,
253
278
default = False )
254
279
280
+ parser .add_argument ("--config" ,
281
+ help = "File path to your config file\n {0}{1}" .format ("[Default (Primary)] {}\n " .format (
282
+ USER_CONFIG_PATH ),
283
+ "[Default (Secondary)] {}\n " .format (
284
+ DEFAULT_CONFIG_PATH ))
285
+ )
286
+
255
287
args = parser .parse_args (args )
256
288
257
289
try :
@@ -281,7 +313,9 @@ def main(prog, args):
281
313
args .replacement if args .replacement is not None else get_default_replacement_path (
282
314
args .lang ),
283
315
args .lang ,
284
- args .parallel )
316
+ args .parallel ,
317
+ get_code_gen_config (args .config )
318
+ )
285
319
286
320
287
321
if __name__ == "__main__" :
0 commit comments