Skip to content

Commit 34cc603

Browse files
Feature/cleaner input code (#41)
* To generate cleaner input code, add get_length() to Index class and add a logic to simplify CalcNode cleaner code
1 parent b6abd39 commit 34cc603

File tree

33 files changed

+455
-385
lines changed

33 files changed

+455
-385
lines changed

atcodertools/codegen/cpp_code_generator.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@ def _loop_header(var: Variable, for_second_index: bool):
1414
index = var.get_first_index()
1515
loop_var = "i"
1616

17-
return "for(int {loop_var} = {start} ; {loop_var} <= {end} ; {loop_var}++){{".format(
17+
return "for(int {loop_var} = 0 ; {loop_var} < {length} ; {loop_var}++){{".format(
1818
loop_var=loop_var,
19-
start=index.get_zero_based_index().min_index,
20-
end=index.get_zero_based_index().max_index
19+
length=index.get_length()
2120
)
2221

2322

@@ -87,15 +86,13 @@ def _generate_declaration(self, var: Variable):
8786
if var.dim_num() == 0:
8887
constructor = ""
8988
elif var.dim_num() == 1:
90-
constructor = "({size}+1)".format(
91-
size=var.get_first_index().get_zero_based_index().max_index)
89+
constructor = "({size})".format(
90+
size=var.get_first_index().get_length())
9291
elif var.dim_num() == 2:
93-
constructor = "({row_size}+1,vector<{type}>({col_size}+1))".format(
92+
constructor = "({row_size}, vector<{type}>({col_size}))".format(
9493
type=self._convert_type(var.type),
95-
row_size=var.get_first_index(
96-
).get_zero_based_index().max_index,
97-
col_size=var.get_second_index(
98-
).get_zero_based_index().max_index
94+
row_size=var.get_first_index().get_length(),
95+
col_size=var.get_second_index().get_length()
9996
)
10097
else:
10198
raise NotImplementedError

atcodertools/codegen/java_code_generator.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,15 @@ def _generate_declaration(self, var: Variable):
2929
if var.dim_num() == 0:
3030
constructor = ""
3131
elif var.dim_num() == 1:
32-
constructor = " = new {type}[(int)({size}+1)]".format(
32+
constructor = " = new {type}[(int)({size})]".format(
3333
type=self._convert_type(var.type),
34-
size=var.get_first_index().get_zero_based_index().max_index
34+
size=var.get_first_index().get_length()
3535
)
3636
elif var.dim_num() == 2:
37-
constructor = " = new {type}[int({row_size}+1)][int({col_size}+1)]".format(
37+
constructor = " = new {type}[int({row_size})][int({col_size})]".format(
3838
type=self._convert_type(var.type),
39-
row_size=var.get_first_index(
40-
).get_zero_based_index().max_index,
41-
col_size=var.get_second_index(
42-
).get_zero_based_index().max_index
39+
row_size=var.get_first_index().get_length(),
40+
col_size=var.get_second_index().get_length()
4341
)
4442
else:
4543
raise NotImplementedError

atcodertools/fmtprediction/calculator.py

Lines changed: 78 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
import itertools
2+
import re
3+
from typing import Dict
4+
5+
16
class CalcParseError(Exception):
27
pass
38

@@ -43,7 +48,7 @@ class CalcNode:
4348

4449
def __init__(self, formula=None):
4550
if formula:
46-
root = parse_to_calc_node(formula)
51+
root = _parse(formula)
4752
self.content = root.content
4853
self.lch = root.lch
4954
self.rch = root.rch
@@ -70,29 +75,52 @@ def __ne__(self, other):
7075
return not self.__eq__(other)
7176

7277
def __str__(self, depth=0):
73-
if self.operator is not None:
74-
lv = self.lch.__str__(depth=depth + 1)
75-
rv = self.rch.__str__(depth=depth + 1)
76-
res = "%s%s%s" % (lv, _operator_to_string(self.operator), rv)
77-
if depth > 0 and (self.operator == add or self.operator == sub):
78-
res = "(%s)" % res
79-
return res
80-
elif isinstance(self.content, int):
81-
return str(self.content)
82-
else:
83-
return self.content
78+
opens = [] # Position list of open brackets
79+
cands = []
80+
original_formula = self.to_string_strictly()
81+
for i, c in enumerate(original_formula):
82+
if c == '(':
83+
opens.append(i)
84+
elif c == ')':
85+
assert len(opens) > 0
86+
cands.append((opens[-1], i))
87+
opens.pop()
88+
pass
89+
90+
values_for_identity_check = [3, 14, 15, 92]
91+
92+
def likely_identical(formula: str):
93+
node = CalcNode(formula)
94+
vars = node.get_all_variables()
95+
for combination in itertools.product(values_for_identity_check, repeat=len(vars)):
96+
val_dict = dict(zip(vars, list(combination)))
97+
if self.evaluate(val_dict) != node.evaluate(val_dict):
98+
return False
99+
return True
100+
101+
# Remove parentheses greedy
102+
res_formula = list(original_formula)
103+
for op, cl in cands:
104+
tmp = res_formula.copy()
105+
tmp[op] = ''
106+
tmp[cl] = ''
107+
if likely_identical("".join(tmp)):
108+
res_formula = tmp
109+
simplified_form = "".join(res_formula)
110+
111+
return simplified_form
84112

85113
def get_all_variables(self):
86-
if self.operator is not None:
114+
if self.is_operator_node():
87115
lv = self.lch.get_all_variables()
88116
rv = self.rch.get_all_variables()
89117
return lv + rv
90-
elif isinstance(self.content, int):
118+
elif self.is_constant_node():
91119
return []
92120
else:
93121
return [self.content]
94122

95-
def evaluate(self, variables=None):
123+
def evaluate(self, variables: Dict[str, int] = None):
96124
if variables is None:
97125
variables = {}
98126
if self.is_operator_node():
@@ -103,10 +131,42 @@ def evaluate(self, variables=None):
103131
return int(self.content)
104132
else:
105133
if self.content not in variables:
106-
raise EvaluateError
134+
raise EvaluateError(
135+
"Found an unknown variable '{}'".format(self.content))
107136
else:
108137
return variables[self.content]
109138

139+
def simplify(self):
140+
current_formula = str(self)
141+
142+
# Really stupid heuristics but covers the major case.
143+
while True:
144+
next_formula = re.sub(r"-1\+1$", "", current_formula)
145+
next_formula = re.sub(r"\+0$", "", next_formula)
146+
next_formula = re.sub(r"-0$", "", next_formula)
147+
if next_formula == current_formula:
148+
break
149+
current_formula = next_formula
150+
151+
return CalcNode(current_formula)
152+
153+
def to_string_strictly(self):
154+
if self.is_operator_node():
155+
return "({lch}{op}{rch})".format(
156+
lch=self.lch.to_string_strictly(),
157+
op=_operator_to_string(self.operator),
158+
rch=self.rch.to_string_strictly()
159+
)
160+
else:
161+
return str(self.content)
162+
163+
164+
def _parse(formula: str):
165+
res, pos = _expr(formula + "$", 0) # $ is put as a terminal character
166+
if pos != len(formula):
167+
raise CalcParseError
168+
return res
169+
110170

111171
def _expr(formula, pos):
112172
res, pos = _term(formula, pos)
@@ -179,19 +239,5 @@ def _factor(formula, pos):
179239
raise CalcParseError
180240

181241

182-
def parse_to_calc_node(formula):
183-
"""
184-
入力
185-
formula # str : 式
186-
出力
187-
#CalcNode : 構文木の根ノード
188-
189-
"""
190-
res, pos = _expr(formula + "$", 0) # $は使わないことにする
191-
if pos != len(formula):
192-
raise CalcParseError
193-
return res
194-
195-
196-
if __name__ == '__main__':
197-
print(CalcNode("N-1-1+1000*N*N").evaluate({"N": 10}))
242+
def parse_to_calc_node(formula: str) -> CalcNode:
243+
return CalcNode(formula)

atcodertools/fmtprediction/predict_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ def up_cast(old_type, new_type):
7070

7171

7272
def is_float(text):
73-
return re.match("-?\d+\.\d+$", text) is not None
73+
return re.match(r"-?\d+\.\d+$", text) is not None
7474

7575

7676
def is_int(text):
77-
return re.match("-?\d+$", text) is not None
77+
return re.match(r"-?\d+$", text) is not None
7878

7979

8080
def _convert_to_proper_type(value: str):

atcodertools/fmtprediction/tokenize_format.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _divide_consecutive_vars(text):
4444

4545
def _sanitized_tokens(input_format: str) -> List[str]:
4646
input_format = input_format.replace("\n", " ").replace("…", " ").replace("...", " ").replace(
47-
"..", " ").replace("\ ", " ").replace("}", "} ").replace(" ", " ").replace(", ", ",")
47+
"..", " ").replace("\\ ", " ").replace("}", "} ").replace(" ", " ").replace(", ", ",")
4848
input_format = _divide_consecutive_vars(input_format)
4949
input_format = _normalize_index(input_format)
5050
input_format = input_format.replace("{", "").replace("}", "")

atcodertools/models/analyzer/index.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ def update(self, new_value: str):
1818
self._update_min(new_value)
1919
self._update_max(new_value)
2020

21-
def get_zero_based_index(self):
22-
res = Index()
23-
res.min_index = CalcNode("0")
24-
res.max_index = CalcNode(
25-
"{max_index}-({min_index})".format(
21+
def get_length(self):
22+
assert self.max_index is not None
23+
assert self.min_index is not None
24+
return CalcNode(
25+
"{max_index}-({min_index})+1".format(
2626
max_index=self.max_index,
27-
min_index=self.min_index))
28-
return res
27+
min_index=self.min_index)
28+
).simplify()
2929

3030
def _update_min(self, new_value: str):
3131
if not new_value.isdecimal():

atcodertools/tools/envgen.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,23 +205,24 @@ def main(prog, args):
205205

206206
parser.add_argument("--lang",
207207
help="programming language of your template code, {}.\n"
208-
.format(" or ".join(SUPPORTED_LANGUAGES)) +
209-
"[Default] {}".format(DEFAULT_LANG),
208+
.format(" or ".join(SUPPORTED_LANGUAGES)) + "[Default] {}".format(DEFAULT_LANG),
210209
default=DEFAULT_LANG,
211210
type=check_lang)
212211

213212
parser.add_argument("--template",
214-
help="file path to your template code\n"
215-
"[Default (C++)] {}\n".format(get_default_template_path('cpp')) +
216-
"[Default (Java)] {}".format(
217-
get_default_template_path('java'))
213+
help="{0}{1}".format("file path to your template code\n"
214+
"[Default (C++)] {}\n".format(
215+
get_default_template_path('cpp')),
216+
"[Default (Java)] {}".format(
217+
get_default_template_path('java')))
218218
)
219219

220220
parser.add_argument("--replacement",
221-
help="file path to the replacement code created when template generation is failed.\n"
222-
"[Default (C++)] {}\n".format(get_default_replacement_path('cpp')) +
223-
"[Default (Java)] {}".format(
224-
get_default_replacement_path('java'))
221+
help="{0}{1}".format(
222+
"file path to the replacement code created when template generation is failed.\n"
223+
"[Default (C++)] {}\n".format(get_default_replacement_path('cpp')),
224+
"[Default (Java)] {}".format(
225+
get_default_replacement_path('java')))
225226
)
226227

227228
parser.add_argument("--parallel",

tests/resources/common/download_testcases.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@ def mkdirs(path):
4141
for idx, sample in enumerate(content.get_samples()):
4242
with open("{}/ex_{}.txt".format(path, idx + 1), "w") as f:
4343
f.write(sample.get_input())
44-
except SampleDetectionError as e:
44+
except SampleDetectionError:
4545
print(
4646
"failed to parse samples for {} {} -- skipping download".format(contest.get_id(),
4747
problem.get_alphabet()))
48-
except InputFormatDetectionError as e:
48+
except InputFormatDetectionError:
4949
print(
5050
"failed to parse input for {} {} -- skipping download".format(contest.get_id(),
5151
problem.get_alphabet()))
52-
except Exception as e:
52+
except Exception:
5353
print("unknown error for {} {} -- skipping download".format(
5454
contest.get_id(), problem.get_alphabet()))
Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
#include <bits/stdc++.h>
2-
using namespace std;
3-
4-
void solve(string S){
5-
6-
}
7-
8-
int main(){
9-
string S;
10-
cin >> S;
11-
solve(S);
12-
return 0;
13-
}
1+
#include <bits/stdc++.h>
2+
using namespace std;
3+
4+
void solve(string S){
5+
6+
}
7+
8+
int main(){
9+
string S;
10+
cin >> S;
11+
solve(S);
12+
return 0;
13+
}
Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
#include <bits/stdc++.h>
2-
using namespace std;
3-
4-
void solve(string S){
5-
6-
}
7-
8-
int main(){
9-
string S;
10-
cin >> S;
11-
solve(S);
12-
return 0;
13-
}
1+
#include <bits/stdc++.h>
2+
using namespace std;
3+
4+
void solve(string S){
5+
6+
}
7+
8+
int main(){
9+
string S;
10+
cin >> S;
11+
solve(S);
12+
return 0;
13+
}

0 commit comments

Comments
 (0)