Skip to content

Commit 7f56f69

Browse files
fukatanikyuridenamida
authored andcommitted
generate rust code. (#86)
* WIP generate rust code. * add test * flake8 * Update readme * rust template: avoid stack over flow. * fix comment * change rust extension * support rust submit * refactoring * fix indent level
1 parent 82213e6 commit 7f56f69

File tree

14 files changed

+754
-6
lines changed

14 files changed

+754
-6
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Python 3.5 以降で動作する [AtCoder](http://atcoder.jp/) からサンプ
1010
- AtCoderへのログイン,入出力例データなどの抽出
1111
- 枝刈り探索による高精度・高速な入力フォーマット解析 (ARC、ABC、AGCについては約9割ほど)
1212
- 問題文中に含まれるMOD値やYES/NO文字列等の定数値抽出
13-
- 入力フォーマット解析結果や抽出した定数値を用いたテンプレートコードの自動生成(C++, Java)
13+
- 入力フォーマット解析結果や抽出した定数値を用いたテンプレートコードの自動生成(C++, Java, Rust)
1414
- カスタムテンプレートに対応
1515
- 他言語対応のためのコントリビューション(≒中間形式からコードに変換する部分のPR)を募集中です!
1616
- コード提出機能
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
from atcodertools.codegen.models.code_gen_args import CodeGenArgs
2+
from atcodertools.codegen.code_generators.cpp import CppCodeGenerator
3+
from atcodertools.codegen.template_engine import render
4+
from atcodertools.fmtprediction.models.format import Pattern, SingularPattern, ParallelPattern, TwoDimensionalPattern
5+
from atcodertools.fmtprediction.models.type import Type
6+
from atcodertools.fmtprediction.models.variable import Variable
7+
8+
9+
# RustCodeGenerator uses part of CppCodeGenerator just for less code clone.
10+
11+
12+
def _loop_header(var: Variable, for_second_index: bool):
13+
if for_second_index:
14+
index = var.second_index
15+
loop_var = "j"
16+
else:
17+
index = var.first_index
18+
loop_var = "i"
19+
20+
return "for {loop_var} in 0..({length}) as usize {{".format(
21+
loop_var=loop_var,
22+
length=index.get_length()
23+
)
24+
25+
26+
class RustCodeGenerator(CppCodeGenerator):
27+
28+
def _input_part(self):
29+
lines = ["let con = read_string();",
30+
"let mut scanner = Scanner::new(&con);"]
31+
for pattern in self._format.sequence:
32+
lines += self._render_pattern(pattern)
33+
return "\n{indent}".format(indent=self._indent(1)).join(lines)
34+
35+
def _convert_type(self, type_: Type) -> str:
36+
if type_ == Type.float:
37+
return "f64"
38+
elif type_ == Type.int:
39+
return "i64"
40+
elif type_ == Type.str:
41+
return "String"
42+
else:
43+
raise NotImplementedError
44+
45+
def _get_declaration_type(self, var: Variable):
46+
if var.dim_num() == 0:
47+
template = "{type}"
48+
elif var.dim_num() == 1:
49+
template = "Vec<{type}>"
50+
elif var.dim_num() == 2:
51+
template = "Vec<Vec<{type}>>"
52+
else:
53+
raise NotImplementedError
54+
return template.format(type=self._convert_type(var.type))
55+
56+
def _actual_arguments(self) -> str:
57+
return ", ".join([v.name for v in self._format.all_vars()])
58+
59+
def _formal_arguments(self):
60+
"""
61+
:return the string form of formal arguments e.g. "N: i64, K: i64, a: Vec<i64>"
62+
"""
63+
return ", ".join([
64+
"{name}: {decl_type}".format(
65+
decl_type=self._get_declaration_type(v),
66+
name=v.name)
67+
for v in self._format.all_vars()
68+
])
69+
70+
def _generate_declaration(self, var: Variable):
71+
if var.dim_num() == 0:
72+
constructor = ""
73+
elif var.dim_num() == 1:
74+
if var.type == Type.str:
75+
constructor = " = vec![String::new(); ({size}) as usize]".format(
76+
type=self._convert_type(var.type),
77+
size=var.first_index.get_length()
78+
)
79+
else:
80+
constructor = " = vec![0{type}; ({size}) as usize]".format(
81+
type=self._convert_type(var.type),
82+
size=var.first_index.get_length()
83+
)
84+
elif var.dim_num() == 2:
85+
if var.type == Type.str:
86+
constructor = " = vec![vec![String::new(); ({col_size}) as usize]; ({row_size}) as usize]".format(
87+
type=self._convert_type(var.type),
88+
row_size=var.first_index.get_length(),
89+
col_size=var.second_index.get_length()
90+
)
91+
else:
92+
constructor = " = vec![vec![0{type}; ({col_size}) as usize]; ({row_size}) as usize]".format(
93+
type=self._convert_type(var.type),
94+
row_size=var.first_index.get_length(),
95+
col_size=var.second_index.get_length()
96+
)
97+
else:
98+
raise NotImplementedError
99+
100+
line = "let mut {name}: {decl_type}{constructor};".format(
101+
name=var.name,
102+
decl_type=self._get_declaration_type(var),
103+
constructor=constructor
104+
)
105+
return line
106+
107+
def _input_code_for_var(self, var: Variable) -> str:
108+
name = self._get_var_name(var)
109+
return '{name} = scanner.next();'.format(name=name)
110+
111+
def _render_pattern(self, pattern: Pattern):
112+
lines = []
113+
for var in pattern.all_vars():
114+
lines.append(self._generate_declaration(var))
115+
116+
representative_var = pattern.all_vars()[0]
117+
if isinstance(pattern, SingularPattern):
118+
lines.append(self._input_code_for_var(representative_var))
119+
elif isinstance(pattern, ParallelPattern):
120+
lines.append(_loop_header(representative_var, False))
121+
for var in pattern.all_vars():
122+
lines.append("{indent}{line}".format(indent=self._indent(1),
123+
line=self._input_code_for_var(var)))
124+
lines.append("}")
125+
elif isinstance(pattern, TwoDimensionalPattern):
126+
lines.append(_loop_header(representative_var, False))
127+
lines.append(
128+
"{indent}{line}".format(indent=self._indent(1), line=_loop_header(representative_var, True)))
129+
for var in pattern.all_vars():
130+
lines.append("{indent}{line}".format(indent=self._indent(2),
131+
line=self._input_code_for_var(var)))
132+
lines.append("{indent}}}".format(indent=self._indent(1)))
133+
lines.append("}")
134+
else:
135+
raise NotImplementedError
136+
137+
return lines
138+
139+
140+
def main(args: CodeGenArgs) -> str:
141+
code_parameters = RustCodeGenerator(
142+
args.format, args.config).generate_parameters()
143+
return render(
144+
args.template,
145+
mod=args.constants.mod,
146+
yes_str=args.constants.yes_str,
147+
no_str=args.constants.no_str,
148+
**code_parameters
149+
)

atcodertools/codegen/code_style_config.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class CodeStyleConfigInitError(Exception):
1414

1515

1616
DEFAULT_WORKSPACE_DIR_PATH = os.path.join(expanduser("~"), "atcoder-workspace")
17-
SUPPORTED_LANGUAGES = ["cpp", "java"]
17+
SUPPORTED_LANGUAGES = ["cpp", "java", "rust"]
1818

1919

2020
class CodeStyleConfig:
@@ -68,9 +68,12 @@ def __init__(self,
6868
if lang == "cpp":
6969
from atcodertools.codegen.code_generators import cpp
7070
self.code_generator = cpp.main
71-
else:
71+
elif lang == "java":
7272
from atcodertools.codegen.code_generators import java
7373
self.code_generator = java.main
74+
else:
75+
from atcodertools.codegen.code_generators import rust
76+
self.code_generator = rust.main
7477

7578
self.template_file = normalize_path(
7679
template_file or get_default_template_path(lang))

atcodertools/tools/envgen.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ class BannedFileDetectedError(Exception):
3939

4040

4141
def extension(lang: str):
42+
if lang == 'rust':
43+
return 'rs'
4244
return lang
4345

4446

@@ -259,7 +261,9 @@ def main(prog, args):
259261
"[Default (C++)] {}\n".format(
260262
get_default_template_path('cpp')),
261263
"[Default (Java)] {}".format(
262-
get_default_template_path('java')))
264+
get_default_template_path('java')),
265+
"[Default (Rust)] {}".format(
266+
get_default_template_path('rust'))),
263267
)
264268

265269
# Deleted functionality

atcodertools/tools/submit.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def infer_detailed_lang(lang: str):
1919
return "Java8 (OpenJDK 1.8.0)"
2020
if lang == "cpp":
2121
return "C++14 (GCC 5.4.1)"
22+
if lang == "rust":
23+
return "Rust (1.15.1)"
2224
raise NotImplementedError
2325

2426

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
use io::*;
2+
use std::*;
3+
4+
{% if mod %}
5+
const MOD: i64 = {{ mod }};
6+
{% endif %}
7+
{% if yes_str %}
8+
const YES: String = "{{ yes_str }}";
9+
{% endif %}
10+
{% if no_str %}
11+
const NO: String = "{{ no_str }}";
12+
{% endif %}
13+
{% if prediction_success %}
14+
fn solve({{ formal_arguments }}) {
15+
16+
}
17+
{% endif %}
18+
19+
fn main() {
20+
{% if prediction_success %}
21+
{{input_part}}
22+
// In order to avoid potential stack overflow, spawn a new thread.
23+
let stack_size = 104_857_600; // 100 MB
24+
let thd = std::thread::Builder::new().stack_size(stack_size);
25+
thd.spawn(move || solve({{ actual_arguments }})).unwrap().join().unwrap();
26+
{% else %}
27+
// Failed to predict input format
28+
{% endif %}
29+
}
30+
31+
pub mod io {
32+
use std;
33+
use std::str::FromStr;
34+
35+
pub struct Scanner<'a> {
36+
iter: std::str::SplitWhitespace<'a>,
37+
}
38+
39+
impl<'a> Scanner<'a> {
40+
pub fn new(s: &'a str) -> Scanner<'a> {
41+
Scanner {
42+
iter: s.split_whitespace(),
43+
}
44+
}
45+
46+
pub fn next<T: FromStr>(&mut self) -> T {
47+
let s = self.iter.next().unwrap();
48+
if let Ok(v) = s.parse::<T>() {
49+
v
50+
} else {
51+
panic!("Parse error")
52+
}
53+
}
54+
55+
pub fn next_vec_len<T: FromStr>(&mut self) -> Vec<T> {
56+
let n: usize = self.next();
57+
self.next_vec(n)
58+
}
59+
60+
pub fn next_vec<T: FromStr>(&mut self, n: usize) -> Vec<T> {
61+
(0..n).map(|_| self.next()).collect()
62+
}
63+
}
64+
65+
pub fn read_string() -> String {
66+
use std::io::Read;
67+
68+
let mut s = String::new();
69+
std::io::stdin().read_to_string(&mut s).unwrap();
70+
s
71+
}
72+
73+
pub fn read_line() -> String {
74+
let mut s = String::new();
75+
std::io::stdin().read_line(&mut s).unwrap();
76+
s.trim_right().to_owned()
77+
}
78+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
use std::*;
2+
3+
fn solve(${formal_arguments}) {
4+
5+
}
6+
7+
fn main() {
8+
${input_part}
9+
solve(${actual_arguments});
10+
}
11+
12+
pub mod io {
13+
use std;
14+
use std::str::FromStr;
15+
16+
pub struct Scanner<'a> {
17+
iter: std::str::SplitWhitespace<'a>,
18+
}
19+
20+
impl<'a> Scanner<'a> {
21+
pub fn new(s: &'a str) -> Scanner<'a> {
22+
Scanner {
23+
iter: s.split_whitespace(),
24+
}
25+
}
26+
27+
pub fn next<T: FromStr>(&mut self) -> T {
28+
let s = self.iter.next().unwrap();
29+
if let Ok(v) = s.parse::<T>() {
30+
v
31+
} else {
32+
panic!("Parse error")
33+
}
34+
}
35+
36+
pub fn next_vec_len<T: FromStr>(&mut self) -> Vec<T> {
37+
let n: usize = self.next();
38+
self.next_vec(n)
39+
}
40+
41+
pub fn next_vec<T: FromStr>(&mut self, n: usize) -> Vec<T> {
42+
(0..n).map(|_| self.next()).collect()
43+
}
44+
}
45+
46+
pub fn read_string() -> String {
47+
use std::io::Read;
48+
49+
let mut s = String::new();
50+
std::io::stdin().read_to_string(&mut s).unwrap();
51+
s
52+
}
53+
54+
pub fn read_line() -> String {
55+
let mut s = String::new();
56+
std::io::stdin().read_line(&mut s).unwrap();
57+
s.trim_right().to_owned()
58+
}
59+
}

0 commit comments

Comments
 (0)