Skip to content

Commit c164bb4

Browse files
committed
Add code gen configuration
1 parent a2dd3b3 commit c164bb4

File tree

10 files changed

+96
-16
lines changed

10 files changed

+96
-16
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import toml
2+
3+
INDENT_TYPE_SPACE = 'space'
4+
INDENT_TYPE_TAB = 'tab'
5+
6+
7+
def _verify_indent_type(indent_type: str):
8+
# indent_type must be 'space' or 'tab'
9+
assert indent_type in [INDENT_TYPE_SPACE, INDENT_TYPE_TAB]
10+
return indent_type
11+
12+
13+
class CodeGenConfig:
14+
def __init__(self,
15+
indent_type: str = INDENT_TYPE_SPACE,
16+
indent_width: int = 4,
17+
):
18+
self.indent_type = _verify_indent_type(indent_type)
19+
self.indent_width = indent_width
20+
21+
def indent(self, depth):
22+
if self.indent_type == INDENT_TYPE_SPACE:
23+
return " " * self.indent_width * depth
24+
return "\t" * self.indent_width * depth
25+
26+
@classmethod
27+
def load(cls, config_file_path):
28+
with open(config_file_path) as f:
29+
kwargs = toml.load(f).get("codegen")
30+
return CodeGenConfig(**kwargs)

atcodertools/codegen/cpp_code_generator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from atcodertools.codegen.code_gen_config import CodeGenConfig
12
from atcodertools.models.analyzer.analyzed_variable import AnalyzedVariable
23
from atcodertools.models.analyzer.simple_format import Pattern, SingularPattern, ParallelPattern, TwoDimensionalPattern
34
from atcodertools.models.predictor.format_prediction_result import FormatPredictionResult
@@ -22,9 +23,10 @@ def _loop_header(var: Variable, for_second_index: bool):
2223

2324
class CppCodeGenerator(CodeGenerator):
2425

25-
def __init__(self, template: str):
26+
def __init__(self, template: str, config: CodeGenConfig = CodeGenConfig()):
2627
self._template = template
2728
self._prediction_result = None
29+
self._config = config
2830

2931
def generate_code(self, prediction_result: FormatPredictionResult):
3032
if prediction_result is None:
@@ -157,7 +159,7 @@ def _render_pattern(self, pattern: Pattern):
157159
return lines
158160

159161
def _indent(self, depth):
160-
return " " * depth
162+
return self._config.indent(depth)
161163

162164

163165
class NoPredictionResultGiven(Exception):

atcodertools/release_management/version_check.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,16 @@ def store_version_cache(version):
4141
f.write("{} {}".format(version, time.time()))
4242

4343

44-
def get_latest_version(user_cache=True):
44+
def get_latest_version(use_cache=True):
4545
try:
46-
if user_cache:
46+
if use_cache:
4747
cached_version = _get_latest_version_cache()
4848
if cached_version:
4949
return cached_version
5050

5151
version = _fetch_latest_version()
5252

53-
if user_cache:
53+
if use_cache:
5454
store_version_cache(version)
5555

5656
return version
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[codegen]
2+
indent_type = 'space' # 'tab' or 'space'
3+
indent_width = 4

atcodertools/tools/envgen.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from time import sleep
99
from typing import Tuple
1010

11+
from atcodertools.codegen.code_gen_config import CodeGenConfig
1112
from atcodertools.codegen.cpp_code_generator import CppCodeGenerator
1213
from atcodertools.codegen.java_code_generator import JavaCodeGenerator
1314
from atcodertools.fileutils.create_contest_file import create_examples, create_code_from_prediction_result
@@ -43,7 +44,9 @@ def prepare_procedure(atcoder_client: AtCoderClient,
4344
workspace_root_path: str,
4445
template_code_path: str,
4546
replacement_code_path: str,
46-
lang: str):
47+
lang: str,
48+
config: CodeGenConfig,
49+
):
4750
pid = problem.get_alphabet()
4851
workspace_dir_path = os.path.join(
4952
workspace_root_path,
@@ -112,7 +115,7 @@ def emit_info(text):
112115

113116
create_code_from_prediction_result(
114117
result,
115-
gen_class(template),
118+
gen_class(template, config),
116119
code_file_path)
117120
emit_info(
118121
"Prediction succeeded -- Saved auto-generated code to '{}'".format(code_file_path))
@@ -141,11 +144,11 @@ def emit_info(text):
141144
emit_info("Saved metadata to {}".format(metadata_path))
142145

143146

144-
def func(argv: Tuple[AtCoderClient, Problem, str, str, str, str]):
145-
atcoder_client, problem, workspace_root_path, template_code_path, replacement_code_path, lang = argv
147+
def func(argv: Tuple[AtCoderClient, Problem, str, str, str, str, CodeGenConfig]):
148+
atcoder_client, problem, workspace_root_path, template_code_path, replacement_code_path, lang, config = argv
146149
prepare_procedure(
147150
atcoder_client, problem, workspace_root_path, template_code_path,
148-
replacement_code_path, lang)
151+
replacement_code_path, lang, config)
149152

150153

151154
def prepare_workspace(atcoder_client: AtCoderClient,
@@ -154,7 +157,9 @@ def prepare_workspace(atcoder_client: AtCoderClient,
154157
template_code_path: str,
155158
replacement_code_path: str,
156159
lang: str,
157-
parallel: bool):
160+
parallel: bool,
161+
config: CodeGenConfig,
162+
):
158163
retry_duration = 1.5
159164
while True:
160165
problem_list = atcoder_client.download_problem_list(
@@ -165,7 +170,7 @@ def prepare_workspace(atcoder_client: AtCoderClient,
165170
logging.warning(
166171
"Failed to fetch. Will retry in {} seconds".format(retry_duration))
167172

168-
tasks = [(atcoder_client, problem, workspace_root_path, template_code_path, replacement_code_path, lang) for
173+
tasks = [(atcoder_client, problem, workspace_root_path, template_code_path, replacement_code_path, lang, config) for
169174
problem in problem_list]
170175
if parallel:
171176
thread_pool = Pool(processes=cpu_count())
@@ -202,6 +207,25 @@ def check_lang(lang: str):
202207
return lang
203208

204209

210+
PRIMARY_DEFAULT_CONFIG_PATH = os.path.join(
211+
expanduser("~"), ".atcodertools.toml")
212+
SECONDARY_DEFAULT_CONFIG_PATH = os.path.abspath(
213+
os.path.join(script_dir_path, "./atcodertools-default.toml"))
214+
215+
216+
def get_code_gen_config(config_path: str = None):
217+
if config_path:
218+
logging.info("Going to load {} as config".format(config_path))
219+
return CodeGenConfig.load(config_path)
220+
if os.path.exists(PRIMARY_DEFAULT_CONFIG_PATH):
221+
logging.info("Going to load {} as config".format(
222+
PRIMARY_DEFAULT_CONFIG_PATH))
223+
return CodeGenConfig.load(PRIMARY_DEFAULT_CONFIG_PATH)
224+
logging.info("Going to load {} as config".format(
225+
SECONDARY_DEFAULT_CONFIG_PATH))
226+
return CodeGenConfig.load(SECONDARY_DEFAULT_CONFIG_PATH)
227+
228+
205229
def main(prog, args):
206230
parser = argparse.ArgumentParser(
207231
prog=prog,
@@ -252,6 +276,12 @@ def main(prog, args):
252276
help="Save no session cache to avoid security risk",
253277
default=False)
254278

279+
parser.add_argument("--config",
280+
help="{0}{1}{2}".format("file path to your config file\n",
281+
"[Default (Primary)] {}\n".format(
282+
PRIMARY_DEFAULT_CONFIG_PATH),
283+
"[Default (Secondary)] {}\n".format(SECONDARY_DEFAULT_CONFIG_PATH)))
284+
255285
args = parser.parse_args(args)
256286

257287
try:
@@ -281,7 +311,9 @@ def main(prog, args):
281311
args.replacement if args.replacement is not None else get_default_replacement_path(
282312
args.lang),
283313
args.lang,
284-
args.parallel)
314+
args.parallel,
315+
get_code_gen_config(args.config)
316+
)
285317

286318

287319
if __name__ == "__main__":

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
beautifulsoup4
22
requests
3-
colorama
3+
colorama
4+
toml
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[codegen]
2+
indent_width = 8

tests/test_codegen.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import unittest
44
import os
55

6+
from atcodertools.codegen.code_gen_config import CodeGenConfig, INDENT_TYPE_SPACE
67
from atcodertools.codegen.code_generator import CodeGenerator
78
from atcodertools.codegen.java_code_generator import JavaCodeGenerator
89
from atcodertools.codegen.cpp_code_generator import CppCodeGenerator
@@ -81,6 +82,12 @@ def get_generator(self, lang: str) -> CodeGenerator:
8182
with open(template_file, 'r') as f:
8283
return self.lang_to_code_generator[lang](f.read())
8384

85+
def test_load_code_gen_config(self):
86+
toml_path = os.path.join(RESOURCE_DIR, "atcodertools-test.toml")
87+
config = CodeGenConfig.load(toml_path)
88+
self.assertEqual(8, config.indent_width)
89+
self.assertEqual(INDENT_TYPE_SPACE, config.indent_type)
90+
8491

8592
if __name__ == "__main__":
8693
unittest.main()

tests/test_envgen.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from os.path import relpath
77

88
from atcodertools.client.atcoder import AtCoderClient
9+
from atcodertools.codegen.code_gen_config import CodeGenConfig
910
from atcodertools.tools.envgen import prepare_workspace, main
1011

1112
RESOURCE_DIR = os.path.join(
@@ -59,7 +60,9 @@ def test_backup(self):
5960
TEMPLATE_PATH,
6061
REPLACEMENT_PATH,
6162
"cpp",
62-
False)
63+
False,
64+
CodeGenConfig(),
65+
)
6366
self.assertDirectoriesEqual(answer_data_dir_path, self.temp_dir)
6467

6568
def assertDirectoriesEqual(self, expected_dir_path, dir_path):

tests/test_version_check.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
class TestTester(unittest.TestCase):
1212
def test_get_latest_version_with_no_error(self):
13-
get_latest_version(user_cache=False)
13+
get_latest_version(use_cache=False)
1414

1515

1616
if __name__ == '__main__':

0 commit comments

Comments
 (0)