Skip to content

Commit 8a710a3

Browse files
Support Problem Constants Injection (#54)
* Add constants prediction logic (mod, yes/no) and its unit test * Support Problem constants injection with new template engine Jinja2 * Raise prediction error instead of returning None and write unit tests * Fix templates * add unittest only_with_no_str * add a unit test with tricky yes/no string case (+ rename test) * add unit test test_nested_embeddings_on_template
1 parent e67a223 commit 8a710a3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+826
-80
lines changed

atcodertools/codegen/code_generator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from atcodertools.models.constpred.problem_constant_set import ProblemConstantSet
12
from atcodertools.models.predictor.format_prediction_result import FormatPredictionResult
23

34
from abc import ABC, abstractmethod
@@ -6,7 +7,7 @@
67
class CodeGenerator(ABC):
78

89
@abstractmethod
def generate_code(self, prediction_result: FormatPredictionResult, constants: ProblemConstantSet):
    """Produce source code for a problem; concrete generators must override this.

    :param prediction_result: the format prediction handed in by the caller.
    :param constants: the problem constants handed in by the caller.
    :raises NotImplementedError: always, since this base declaration has no implementation.
    """
    raise NotImplementedError()
1112

1213

atcodertools/codegen/cpp_code_generator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from atcodertools.codegen.code_gen_config import CodeGenConfig
22
from atcodertools.models.analyzer.analyzed_variable import AnalyzedVariable
33
from atcodertools.models.analyzer.simple_format import Pattern, SingularPattern, ParallelPattern, TwoDimensionalPattern
4+
from atcodertools.models.constpred.problem_constant_set import ProblemConstantSet
45
from atcodertools.models.predictor.format_prediction_result import FormatPredictionResult
56
from atcodertools.models.predictor.variable import Variable
67
from atcodertools.codegen.code_generator import CodeGenerator
@@ -28,14 +29,20 @@ def __init__(self, template: str, config: CodeGenConfig = CodeGenConfig()):
2829
self._prediction_result = None
2930
self._config = config
3031

31-
def generate_code(self, prediction_result: FormatPredictionResult):
32+
def generate_code(self, prediction_result: FormatPredictionResult,
                  constants: ProblemConstantSet = None):
    """Render the stored template for the given prediction result.

    :param prediction_result: the predicted input format; must not be None.
    :param constants: problem constants whose ``mod``, ``yes_str`` and
        ``no_str`` are passed to the template. When omitted, a fresh empty
        ``ProblemConstantSet`` is used.
    :return: whatever ``render`` returns for the template.
    :raises NoPredictionResultGiven: if ``prediction_result`` is None.
    """
    if prediction_result is None:
        raise NoPredictionResultGiven

    # A default of ``ProblemConstantSet()`` in the signature would be created
    # once at definition time and shared (and mutable) across every call, so
    # the empty set is built per call instead.
    if constants is None:
        constants = ProblemConstantSet()

    self._prediction_result = prediction_result

    return render(self._template,
                  formal_arguments=self._formal_arguments(),
                  actual_arguments=self._actual_arguments(),
                  input_part=self._input_part(),
                  mod=constants.mod,
                  yes_str=constants.yes_str,
                  no_str=constants.no_str,
                  )
3946

4047
def _input_part(self):
4148
lines = []

atcodertools/codegen/template_engine.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import string
22
import re
3+
import warnings
4+
5+
import jinja2
36

47

58
def _substitute(s, reps):
@@ -17,14 +20,32 @@ def _substitute(s, reps):
1720
sep = ('\n' + m.group(1)) if m.group(1).strip() == '' else '\n'
1821

1922
cr[m.group(2)] = sep.join(reps[m.group(2)])
20-
i += m.end() # continue past last processed replaceable token
23+
i += m.end() # continue past last processed replaceable token
2124
return t.substitute(cr) # we can now substitute
2225

2326

24-
def render(s, **args):
27+
def render(template, **kwargs):
    """Render ``template`` with ``kwargs``, choosing the engine from the template text.

    A template containing ``${`` is rendered by ``old_render`` after a
    ``UserWarning`` is emitted; any other template is rendered by
    ``render_by_jinja``.
    """
    uses_old_syntax = "${" in template
    if not uses_old_syntax:
        return render_by_jinja(template, **kwargs)

    # Old-style template: keep rendering it with the old engine for
    # backward compatibility, but tell the user it is deprecated.
    warnings.warn(
        "The old template engine with ${} is deprecated. Please use the new Jinja2 template engine.", UserWarning)
    return old_render(template, **kwargs)
37+
38+
39+
def old_render(template, **kwargs):
    """Render ``template`` with the ``${}``-based engine.

    Every keyword value that is not already a list is wrapped in a
    one-element list before being handed to ``_substitute``.
    """
    # This render function used to be used before version 1.0.3
    list_valued = {
        key: (value if isinstance(value, list) else [value])
        for key, value in kwargs.items()
    }
    return _substitute(template, list_valued)
47+
48+
49+
def render_by_jinja(template, **kwargs):
    """Render ``template`` as a Jinja2 template and append a newline to the result."""
    rendered = jinja2.Template(template).render(**kwargs)
    return rendered + "\n"

atcodertools/constprediction/__init__.py

Whitespace-only changes.
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import logging
2+
import re
3+
from typing import Tuple, Optional
4+
5+
from atcodertools.models.constpred.problem_constant_set import ProblemConstantSet
6+
from bs4 import BeautifulSoup
7+
8+
from atcodertools.models.problem_content import ProblemContent, InputFormatDetectionError, SampleDetectionError
9+
10+
11+
class YesNoPredictionFailedError(Exception):
    """Raised by ``predict_yes_no`` when the problem content cannot be parsed."""
13+
14+
15+
class MultipleModCandidatesError(Exception):
    """Raised when more than one distinct modulo candidate is extracted.

    The offending candidates are available on the ``cands`` attribute.
    """

    def __init__(self, cands):
        # Passing ``cands`` up keeps ``args`` equal to ``(cands,)``.
        super().__init__(cands)
        self.cands = cands
19+
20+
21+
# Substrings whose presence marks a sentence as talking about a modulus.
MOD_ANCHORS = [
    "余り",
    "あまり",
    "mod",
    "割っ",
    "modulo",
]

# Patterns tried against each candidate sentence; group 1 captures the digits.
MOD_STRATEGY_RE_LIST = [
    re.compile("([0-9]+).?.?.?で割った"),
    re.compile("modu?l?o?[^0-9]?[^0-9]?[^0-9]?([0-9]+)"),
]


def is_mod_context(sentence):
    """Return True when ``sentence`` contains at least one ``MOD_ANCHORS`` substring."""
    return any(anchor in sentence for anchor in MOD_ANCHORS)
34+
35+
36+
def predict_modulo(html: str) -> Optional[int]:
    """Extract the modulus mentioned in a problem statement, if any.

    The text of ``html`` is split into lines; lines that look like they talk
    about a modulus are normalized and matched against every pattern in
    ``MOD_STRATEGY_RE_LIST``.

    :param html: the problem page HTML.
    :return: the single extracted value, or None when nothing was extracted.
    :raises MultipleModCandidatesError: when two or more distinct values are
        extracted.
    """
    def normalize(sentence):
        # Drop backslashes, braces, commas and spaces so that forms such as
        # "10^{9}+7" or "1,000,000,007" collapse to plain digits, then
        # lower-case for the (lower-case) regular expressions.
        return sentence.replace('\\', '').replace("{", "").replace("}", "").replace(",", "").replace(" ", "").replace(
            "10^9+7", "1000000007").lower().strip()

    soup = BeautifulSoup(html, "html.parser")
    sentences = soup.get_text().split("\n")
    # The anchors are lower-case while the raw text may not be; filter on the
    # lower-cased sentence so that e.g. "MOD" or "Modulo" lines are not
    # discarded before the case-insensitive matching below can see them.
    sentences = [normalize(s) for s in sentences if is_mod_context(s.lower())]

    mod_cands = set()
    for s in sentences:
        for regexp in MOD_STRATEGY_RE_LIST:
            m = regexp.search(s)
            if m is not None:
                mod_cands.add(int(m.group(1)))

    if not mod_cands:
        return None

    if len(mod_cands) == 1:
        return next(iter(mod_cands))

    raise MultipleModCandidatesError(mod_cands)
61+
62+
63+
def predict_yes_no(html: str) -> Tuple[Optional[str], Optional[str]]:
    """Guess the exact "yes" and "no" output strings from the sample outputs.

    Each line of every sample output is stripped and compared
    case-insensitively with "yes"/"possible" and "no"/"impossible". The
    original spelling of a matching line is returned. When several lines
    match, the last one in sample order wins.

    :param html: the problem page HTML.
    :return: ``(yes_str, no_str)``; either element is None when no sample
        output line matched.
    :raises YesNoPredictionFailedError: when the problem content cannot be
        parsed (input format or sample detection failed).
    """
    try:
        # A list (rather than a set) keeps the sample order, so the result
        # never depends on string-hash iteration order.
        outputs = []
        for sample in ProblemContent.from_html(html).get_samples():
            for x in sample.get_output().split("\n"):
                outputs.append(x.strip())
    except (InputFormatDetectionError, SampleDetectionError) as e:
        raise YesNoPredictionFailedError(e) from e

    yes_kws = ["yes", "possible"]
    no_kws = ["no", "impossible"]

    yes_str = None
    no_str = None
    for val in outputs:
        lowered = val.lower()
        if lowered in yes_kws:
            yes_str = val
        if lowered in no_kws:
            no_str = val

    return yes_str, no_str
84+
85+
86+
def predict_constants(html: str) -> ProblemConstantSet:
    """Predict the yes/no strings and the modulus of a problem page.

    A failed yes/no prediction yields None for both strings. An ambiguous
    modulo prediction is logged as a warning and yields None for ``mod``.

    :param html: the problem page HTML.
    :return: a ``ProblemConstantSet`` holding the predicted values.
    """
    try:
        affirmative, negative = predict_yes_no(html)
    except YesNoPredictionFailedError:
        affirmative, negative = None, None

    modulus = None
    try:
        modulus = predict_modulo(html)
    except MultipleModCandidatesError as e:
        logging.warning("Modulo prediction failed -- "
                        "two or more candidates {} are detected as modulo values".format(e.cands))

    return ProblemConstantSet(mod=modulus, yes_str=affirmative, no_str=negative)

atcodertools/fileutils/create_contest_file.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import List
33

44
from atcodertools.codegen.code_generator import CodeGenerator
5+
from atcodertools.models.constpred.problem_constant_set import ProblemConstantSet
56
from atcodertools.models.sample import Sample
67
from atcodertools.models.predictor.format_prediction_result import FormatPredictionResult
78

@@ -11,8 +12,11 @@ def _make_text_file(file_path, text):
1112
f.write(text)
1213

1314

14-
def create_code_from_prediction_result(result: FormatPredictionResult, code_generator: CodeGenerator, file_path: str):
15-
_make_text_file(file_path, code_generator.generate_code(result))
15+
def create_code_from(result: FormatPredictionResult,
                     constants: ProblemConstantSet,
                     code_generator: CodeGenerator,
                     file_path: str):
    """Generate code with ``code_generator`` and write it to ``file_path``.

    :param result: passed to ``code_generator.generate_code`` as its first argument.
    :param constants: passed to ``code_generator.generate_code`` as its second argument.
    :param code_generator: the generator that produces the text.
    :param file_path: destination handed to ``_make_text_file``.
    """
    generated_source = code_generator.generate_code(result, constants)
    _make_text_file(file_path, generated_source)
1620

1721

1822
def create_example(example: Sample, in_example_name: str, out_example_name: str):

atcodertools/models/constpred/__init__.py

Whitespace-only changes.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
2+
3+
class ProblemConstantSet:
    """Plain holder for the constants predicted for one problem.

    Every value defaults to None.
    """

    def __init__(self,
                 mod: int = None,
                 yes_str: str = None,
                 no_str: str = None,
                 ):
        self.mod = mod
        self.yes_str = yes_str
        self.no_str = no_str

    def __repr__(self):
        # Readable representation for logs and test-failure output.
        return "{}(mod={!r}, yes_str={!r}, no_str={!r})".format(
            type(self).__name__, self.mod, self.yes_str, self.no_str)

atcodertools/models/problem_content.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Tuple
1+
from typing import List, Tuple, Optional
22

33
from bs4 import BeautifulSoup
44

@@ -36,13 +36,17 @@ class InputFormatDetectionError(Exception):
3636

3737
class ProblemContent:
3838

39-
def __init__(self, input_format_text: str = None, samples: List[Sample] = None):
39+
def __init__(self, input_format_text: Optional[str] = None,
             samples: Optional[List[Sample]] = None,
             original_html: Optional[str] = None,
             ):
    """Store the given pieces of a problem page; each one defaults to None.

    :param input_format_text: stored as ``self.input_format_text``.
    :param samples: stored as ``self.samples``.
    :param original_html: stored as ``self.original_html``.
    """
    self.input_format_text = input_format_text
    self.samples = samples
    self.original_html = original_html
4246

4347
@classmethod
44-
def from_html(cls, html: str = None):
45-
res = ProblemContent()
48+
def from_html(cls, html: str):
49+
res = ProblemContent(original_html=html)
4650
soup = BeautifulSoup(html, "html.parser")
4751
res.input_format_text, res.samples = res._extract_input_format_and_samples(
4852
soup)
@@ -81,13 +85,17 @@ def _extract_input_format_and_samples(soup) -> Tuple[str, List[Sample]]:
8185
if len(input_tags) != len(output_tags):
8286
raise SampleDetectionError
8387

84-
res = [Sample(normalize(in_tag.text), normalize(out_tag.text))
85-
for in_tag, out_tag in zip(input_tags, output_tags)]
88+
try:
89+
res = [Sample(normalize(in_tag.text), normalize(out_tag.text))
90+
for in_tag, out_tag in zip(input_tags, output_tags)]
91+
92+
if input_format_tag is None:
93+
raise InputFormatDetectionError
8694

87-
if input_format_tag is None:
95+
input_format_text = normalize(input_format_tag.text)
96+
except AttributeError:
8897
raise InputFormatDetectionError
8998

90-
input_format_text = normalize(input_format_tag.text)
9199
return input_format_text, res
92100

93101
@staticmethod

atcodertools/tools/envgen.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
from atcodertools.codegen.code_gen_config import CodeGenConfig
1212
from atcodertools.codegen.cpp_code_generator import CppCodeGenerator
1313
from atcodertools.codegen.java_code_generator import JavaCodeGenerator
14-
from atcodertools.fileutils.create_contest_file import create_examples, create_code_from_prediction_result
14+
from atcodertools.constprediction.constants_prediction import predict_constants
15+
from atcodertools.fileutils.create_contest_file import create_examples, \
16+
create_code_from
1517
from atcodertools.models.problem_content import InputFormatDetectionError, SampleDetectionError
1618
from atcodertools.client.atcoder import AtCoderClient, Contest, LoginError
1719
from atcodertools.fmtprediction.predict_format import FormatPredictor, NoPredictionResultError, \
@@ -101,20 +103,22 @@ def emit_info(text):
101103
new_path))
102104

103105
try:
104-
result = FormatPredictor().predict(content)
105-
106-
with open(template_code_path, "r") as f:
107-
template = f.read()
108-
109106
if lang == "cpp":
110107
gen_class = CppCodeGenerator
111108
elif lang == "java":
112109
gen_class = JavaCodeGenerator
113110
else:
114111
raise NotImplementedError("only supporting cpp and java")
115112

116-
create_code_from_prediction_result(
113+
with open(template_code_path, "r") as f:
114+
template = f.read()
115+
116+
result = FormatPredictor().predict(content)
117+
constants = predict_constants(content.original_html)
118+
119+
create_code_from(
117120
result,
121+
constants,
118122
gen_class(template, config),
119123
code_file_path)
120124
emit_info(

0 commit comments

Comments
 (0)