8
8
from time import sleep
9
9
from typing import Tuple
10
10
11
+ from atcodertools .codegen .code_gen_config import CodeGenConfig
11
12
from atcodertools .codegen .cpp_code_generator import CppCodeGenerator
12
13
from atcodertools .codegen .java_code_generator import JavaCodeGenerator
13
14
from atcodertools .fileutils .create_contest_file import create_examples , create_code_from_prediction_result
@@ -43,7 +44,9 @@ def prepare_procedure(atcoder_client: AtCoderClient,
43
44
workspace_root_path : str ,
44
45
template_code_path : str ,
45
46
replacement_code_path : str ,
46
- lang : str ):
47
+ lang : str ,
48
+ config : CodeGenConfig ,
49
+ ):
47
50
pid = problem .get_alphabet ()
48
51
workspace_dir_path = os .path .join (
49
52
workspace_root_path ,
@@ -112,7 +115,7 @@ def emit_info(text):
112
115
113
116
create_code_from_prediction_result (
114
117
result ,
115
- gen_class (template ),
118
+ gen_class (template , config ),
116
119
code_file_path )
117
120
emit_info (
118
121
"Prediction succeeded -- Saved auto-generated code to '{}'" .format (code_file_path ))
@@ -141,11 +144,11 @@ def emit_info(text):
141
144
emit_info ("Saved metadata to {}" .format (metadata_path ))
142
145
143
146
144
- def func (argv : Tuple [AtCoderClient , Problem , str , str , str , str ]):
145
- atcoder_client , problem , workspace_root_path , template_code_path , replacement_code_path , lang = argv
147
+ def func (argv : Tuple [AtCoderClient , Problem , str , str , str , str , CodeGenConfig ]):
148
+ atcoder_client , problem , workspace_root_path , template_code_path , replacement_code_path , lang , config = argv
146
149
prepare_procedure (
147
150
atcoder_client , problem , workspace_root_path , template_code_path ,
148
- replacement_code_path , lang )
151
+ replacement_code_path , lang , config )
149
152
150
153
151
154
def prepare_workspace (atcoder_client : AtCoderClient ,
@@ -154,7 +157,9 @@ def prepare_workspace(atcoder_client: AtCoderClient,
154
157
template_code_path : str ,
155
158
replacement_code_path : str ,
156
159
lang : str ,
157
- parallel : bool ):
160
+ parallel : bool ,
161
+ config : CodeGenConfig ,
162
+ ):
158
163
retry_duration = 1.5
159
164
while True :
160
165
problem_list = atcoder_client .download_problem_list (
@@ -165,7 +170,7 @@ def prepare_workspace(atcoder_client: AtCoderClient,
165
170
logging .warning (
166
171
"Failed to fetch. Will retry in {} seconds" .format (retry_duration ))
167
172
168
- tasks = [(atcoder_client , problem , workspace_root_path , template_code_path , replacement_code_path , lang ) for
173
+ tasks = [(atcoder_client , problem , workspace_root_path , template_code_path , replacement_code_path , lang , config ) for
169
174
problem in problem_list ]
170
175
if parallel :
171
176
thread_pool = Pool (processes = cpu_count ())
@@ -202,6 +207,25 @@ def check_lang(lang: str):
202
207
return lang
203
208
204
209
210
+ PRIMARY_DEFAULT_CONFIG_PATH = os .path .join (
211
+ expanduser ("~" ), ".atcodertools.toml" )
212
+ SECONDARY_DEFAULT_CONFIG_PATH = os .path .abspath (
213
+ os .path .join (script_dir_path , "./atcodertools-default.toml" ))
214
+
215
+
216
+ def get_code_gen_config (config_path : str = None ):
217
+ if config_path :
218
+ logging .info ("Going to load {} as config" .format (config_path ))
219
+ return CodeGenConfig .load (config_path )
220
+ if os .path .exists (PRIMARY_DEFAULT_CONFIG_PATH ):
221
+ logging .info ("Going to load {} as config" .format (
222
+ PRIMARY_DEFAULT_CONFIG_PATH ))
223
+ return CodeGenConfig .load (PRIMARY_DEFAULT_CONFIG_PATH )
224
+ logging .info ("Going to load {} as config" .format (
225
+ SECONDARY_DEFAULT_CONFIG_PATH ))
226
+ return CodeGenConfig .load (SECONDARY_DEFAULT_CONFIG_PATH )
227
+
228
+
205
229
def main (prog , args ):
206
230
parser = argparse .ArgumentParser (
207
231
prog = prog ,
@@ -252,6 +276,12 @@ def main(prog, args):
252
276
help = "Save no session cache to avoid security risk" ,
253
277
default = False )
254
278
279
+ parser .add_argument ("--config" ,
280
+ help = "{0}{1}{2}" .format ("file path to your config file\n " ,
281
+ "[Default (Primary)] {}\n " .format (
282
+ PRIMARY_DEFAULT_CONFIG_PATH ),
283
+ "[Default (Secondary)] {}\n " .format (SECONDARY_DEFAULT_CONFIG_PATH )))
284
+
255
285
args = parser .parse_args (args )
256
286
257
287
try :
@@ -281,7 +311,9 @@ def main(prog, args):
281
311
args .replacement if args .replacement is not None else get_default_replacement_path (
282
312
args .lang ),
283
313
args .lang ,
284
- args .parallel )
314
+ args .parallel ,
315
+ get_code_gen_config (args .config )
316
+ )
285
317
286
318
287
319
if __name__ == "__main__" :
0 commit comments