
Commit c791a23

reduce memory usage of the construct_automata script (#1481)
* remove unneeded loop in `SpliceMutator::mutate`: previously we searched for the first and the last difference between exactly the same 2 inputs 3 times in a loop
* remove unused struct fields
* avoid allocating strings for `Transition`s
* avoid allocating `String`s for `Stack`s
* avoid allocating `String`s for `Element`s
* apply some clippy lints
* some more clippy lints
* simplify regex
* remove superfluous if condition
* remove the `Rc<_>` in `Element`
* small cleanups and regex fix
* avoid allocating a vector for the culled pda
* bug fix
* bug fix
* reintroduce the `Rc`, but make it use the *one* allocated `VecDeque` this time
* slim down dependencies
* use `Box<[&str]>` for sorted state stacks; this saves us a whopping 8 bytes ;), since we don't have to store the capacity (see the size sketch below)
* revert the changes from 9ffa715 (fixes a bug)
* apply clippy lint

---------

Co-authored-by: Andrea Fioraldi <[email protected]>
1 parent 4c0e01c commit c791a23
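
Size note (illustrative sketch, not part of the commit): the 8 bytes saved by `Box<[&str]>` over `Vec<&str>` come from dropping the capacity word that every `Vec` carries.

use std::mem::size_of;

fn main() {
    // Vec<&str> holds pointer + length + capacity (3 usize-sized words).
    assert_eq!(size_of::<Vec<&str>>(), 3 * size_of::<usize>());
    // Box<[&str]> is a fat pointer holding only pointer + length (2 words).
    assert_eq!(size_of::<Box<[&str]>>(), 2 * size_of::<usize>());
}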

File tree

2 files changed: 66 additions, 63 deletions


utils/gramatron/construct_automata/Cargo.toml

Lines changed: 2 additions & 2 deletions
@@ -15,9 +15,9 @@ categories = ["development-tools::testing", "emulators", "embedded", "os", "no-s
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-libafl = { path = "../../../libafl" }
+libafl = { path = "../../../libafl", default-features = false }
 serde_json = "1.0"
 regex = "1"
 postcard = { version = "1.0", features = ["alloc"], default-features = false } # no_std compatible serde serialization format
 clap = { version = "4.0", features = ["derive"] }
-log = "0.4.20"
+# log = "0.4.20"

utils/gramatron/construct_automata/src/main.rs

Lines changed: 64 additions & 61 deletions
@@ -1,5 +1,5 @@
 use std::{
-    collections::{HashMap, HashSet, VecDeque},
+    collections::{HashSet, VecDeque},
     fs,
     io::{BufReader, Write},
     path::{Path, PathBuf},
@@ -49,51 +49,52 @@ fn read_grammar_from_file<P: AsRef<Path>>(path: P) -> Value {
 }
 
 #[derive(Debug)]
-struct Element {
+struct Element<'src> {
     pub state: usize,
-    pub items: Rc<VecDeque<String>>,
+    pub items: Rc<VecDeque<&'src str>>,
 }
 
 #[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
-struct Transition {
+struct Transition<'src> {
     pub source: usize,
     pub dest: usize,
-    pub ss: Vec<String>,
-    pub terminal: String,
-    pub is_regex: bool,
-    pub stack: Rc<VecDeque<String>>,
+    // pub ss: Vec<String>,
+    pub terminal: &'src str,
+    // pub is_regex: bool,
+    pub stack_len: usize,
 }
 
 #[derive(Default)]
-struct Stacks {
-    pub q: HashMap<usize, VecDeque<String>>,
-    pub s: HashMap<usize, Vec<String>>,
+struct Stacks<'src> {
+    pub q: Vec<Rc<VecDeque<&'src str>>>,
+    pub s: Vec<Box<[&'src str]>>,
 }
 
-fn tokenize(rule: &str) -> (String, Vec<String>, bool) {
+fn tokenize(rule: &str) -> (&str, Vec<&str>) {
     let re = RE.get_or_init(|| Regex::new(r"([r])*'([\s\S]+)'([\s\S]*)").unwrap());
+    // let re = RE.get_or_init(|| Regex::new(r"'([\s\S]+)'([\s\S]*)").unwrap());
     let cap = re.captures(rule).unwrap();
-    let is_regex = cap.get(1).is_some();
-    let terminal = cap.get(2).unwrap().as_str().to_owned();
+    // let is_regex = cap.get(1).is_some();
+    let terminal = cap.get(2).unwrap().as_str();
     let ss = cap.get(3).map_or(vec![], |m| {
         m.as_str()
             .split_whitespace()
-            .map(ToOwned::to_owned)
+            // .map(ToOwned::to_owned)
             .collect()
     });
     if terminal == "\\n" {
-        ("\n".into(), ss, is_regex)
+        ("\n", ss /*is_regex*/)
     } else {
-        (terminal, ss, is_regex)
+        (terminal, ss /*is_regex*/)
     }
 }
 
-fn prepare_transitions(
-    grammar: &Value,
-    pda: &mut Vec<Transition>,
-    state_stacks: &mut Stacks,
+fn prepare_transitions<'pda, 'src: 'pda>(
+    grammar: &'src Value,
+    pda: &'pda mut Vec<Transition<'src>>,
+    state_stacks: &mut Stacks<'src>,
     state_count: &mut usize,
-    worklist: &mut VecDeque<Element>,
+    worklist: &mut VecDeque<Element<'src>>,
     element: &Element,
     stack_limit: usize,
 ) {
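
Note (illustrative sketch, hypothetical helper, not the script's actual `tokenize`): returning borrowed `&str` tokens avoids a `String` allocation per token, because the returned slices simply point into the rule text, whose grammar stays alive for the whole run.

// Simplified stand-in: rules are assumed to look like 'terminal' NT1 NT2 ...
fn split_rule(rule: &str) -> Option<(&str, Vec<&str>)> {
    // Borrow the terminal and the nonterminals straight from `rule`.
    let rest = rule.strip_prefix('\'')?;
    let (terminal, tail) = rest.split_once('\'')?;
    Some((terminal, tail.split_whitespace().collect()))
}

fn main() {
    let (terminal, nonterminals) = split_rule("'for' WS EXPR").unwrap();
    assert_eq!(terminal, "for");
    assert_eq!(nonterminals, ["WS", "EXPR"]);
}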
@@ -102,46 +103,46 @@ fn prepare_transitions(
     }
 
     let state = element.state;
-    let nonterminal = &element.items[0];
+    let nonterminal = element.items[0];
     let rules = grammar[nonterminal].as_array().unwrap();
     // let mut i = 0;
     'rules_loop: for rule in rules {
         let rule = rule.as_str().unwrap();
-        let (terminal, ss, is_regex) = tokenize(rule);
+        let (terminal, ss /*_is_regex*/) = tokenize(rule);
         let dest = *state_count;
 
         // log::trace!("Rule \"{}\", {} over {}", &rule, i, rules.len());
 
         // Creating a state stack for the new state
         let mut state_stack = state_stacks
             .q
-            .get(&state)
-            .map_or(VecDeque::new(), Clone::clone);
-        if !state_stack.is_empty() {
-            state_stack.pop_front();
-        }
-        for symbol in ss.iter().rev() {
-            state_stack.push_front(symbol.clone());
+            .get(state.wrapping_sub(1))
+            .map_or(VecDeque::new(), |state_stack| (**state_stack).clone());
+
+        state_stack.pop_front();
+        for symbol in ss.into_iter().rev() {
+            state_stack.push_front(symbol);
         }
-        let mut state_stack_sorted: Vec<_> = state_stack.iter().cloned().collect();
-        state_stack_sorted.sort();
+        let mut state_stack_sorted: Box<_> = state_stack.iter().copied().collect();
+        state_stack_sorted.sort_unstable();
 
         let mut transition = Transition {
             source: state,
             dest,
-            ss,
+            // ss,
             terminal,
-            is_regex,
-            stack: Rc::new(state_stack.clone()),
+            // is_regex,
+            // stack: Rc::new(state_stack.clone()),
+            stack_len: state_stack.len(),
         };
 
         // Check if a recursive transition state being created, if so make a backward
         // edge and don't add anything to the worklist
-        for (key, val) in &state_stacks.s {
-            if state_stack_sorted == *val {
-                transition.dest = *key;
+        for (dest, stack) in state_stacks.s.iter().enumerate() {
+            if state_stack_sorted == *stack {
+                transition.dest = dest + 1;
                 // i += 1;
-                pda.push(transition.clone());
+                pda.push(transition);
 
                 // If a recursive transition exercised don't add the same transition as a new
                 // edge, continue onto the next transitions
@@ -151,18 +152,23 @@ fn prepare_transitions(
 
         // If the generated state has a stack size > stack_limit then that state is abandoned
        // and not added to the FSA or the worklist for further expansion
-        if stack_limit > 0 && transition.stack.len() > stack_limit {
+        if stack_limit > 0 && transition.stack_len > stack_limit {
            // TODO add to unexpanded_rules
            continue;
        }
 
+        let state_stack = Rc::new(state_stack);
+
        // Create transitions for the non-recursive relations and add to the worklist
        worklist.push_back(Element {
            state: dest,
-            items: transition.stack.clone(),
+            items: Rc::clone(&state_stack),
        });
-        state_stacks.q.insert(dest, state_stack);
-        state_stacks.s.insert(dest, state_stack_sorted);
+
+        // since each index corresponds to `state_count - 1`
+        // index with `dest - 1`
+        state_stacks.q.push(state_stack);
+        state_stacks.s.push(state_stack_sorted);
        pda.push(transition);
 
        println!("worklist size: {}", worklist.len());
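
Note (illustrative sketch, assumed reasoning, simplified): with the `HashMap` replaced by a `Vec`, states are numbered from 1, so state `n`'s stack lives at index `n - 1`; `wrapping_sub(1)` avoids an underflow panic when there is no predecessor and lets `get` simply return `None`.

use std::collections::VecDeque;
use std::rc::Rc;

fn main() {
    // Hypothetical stand-in for `state_stacks.q`: stacks are pushed in
    // creation order, so state n sits at index n - 1.
    let mut q: Vec<Rc<VecDeque<&str>>> = Vec::new();
    q.push(Rc::new(VecDeque::from(vec!["EXPR"]))); // stack for state 1

    let state = 1usize;
    assert!(q.get(state.wrapping_sub(1)).is_some());

    // State 0 has no stored stack; wrapping_sub turns 0 - 1 into usize::MAX,
    // so `get` returns None instead of the subtraction panicking in debug builds.
    assert!(q.get(0usize.wrapping_sub(1)).is_none());
}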
@@ -205,11 +211,11 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
     if stack_limit > 0 {
         let mut culled_pda = Vec::with_capacity(pda.len());
         let mut blocklist = HashSet::new();
-        //let mut culled_pda_unique = HashSet::new();
+        // let mut culled_pda_unique = HashSet::new();
 
         for final_state in &finals {
             for transition in pda {
-                if transition.dest == *final_state && transition.stack.len() > 0 {
+                if transition.dest == *final_state && transition.stack_len > 0 {
                     blocklist.insert(transition.dest);
                 } else {
                     culled_pda.push(transition);
@@ -223,7 +229,9 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
         let culled_finals: HashSet<usize> = finals.difference(&blocklist).copied().collect();
         assert!(culled_finals.len() == 1);
 
-        for transition in &culled_pda {
+        let culled_pda_len = culled_pda.len();
+
+        for transition in culled_pda {
             if blocklist.contains(&transition.dest) {
                 continue;
             }
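
Note (illustrative sketch, assumed rationale): iterating `culled_pda` by value moves the `Vec`, so its length has to be captured before the loop; calling `culled_pda.len()` inside the loop body would no longer compile.

fn main() {
    let culled_pda: Vec<&str> = vec!["t0", "t1", "t2"];

    // Capture the length before the Vec is consumed by the by-value loop.
    let culled_pda_len = culled_pda.len();

    for (num_transition, t) in culled_pda.into_iter().enumerate() {
        println!("processed {num_transition} transitions over {culled_pda_len}: {t}");
    }
}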
@@ -234,15 +242,11 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
             }
             memoized[state].push(Trigger {
                 dest: transition.dest,
-                term: transition.terminal.clone(),
+                term: transition.terminal.to_string(),
             });
 
             if num_transition % 4096 == 0 {
-                println!(
-                    "processed {} transitions over {}",
-                    num_transition,
-                    culled_pda.len()
-                );
+                println!("processed {num_transition} transitions over {culled_pda_len}",);
             }
         }
 
@@ -261,8 +265,8 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
         */
 
         Automaton {
-            init_state: initial.iter().next().copied().unwrap(),
-            final_state: culled_finals.iter().next().copied().unwrap(),
+            init_state: initial.into_iter().next().unwrap(),
+            final_state: culled_finals.into_iter().next().unwrap(),
             pda: memoized,
         }
     } else {
@@ -275,7 +279,7 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
         }
         memoized[state].push(Trigger {
             dest: transition.dest,
-            term: transition.terminal.clone(),
+            term: transition.terminal.to_string(),
         });
 
         if num_transition % 4096 == 0 {
@@ -288,8 +292,8 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
     }
 
     Automaton {
-        init_state: initial.iter().next().copied().unwrap(),
-        final_state: finals.iter().next().copied().unwrap(),
+        init_state: initial.into_iter().next().unwrap(),
+        final_state: finals.into_iter().next().unwrap(),
         pda: memoized,
     }
 }
@@ -308,7 +312,7 @@ fn main() {
     let mut pda = vec![];
 
     let grammar = read_grammar_from_file(grammar_file);
-    let start_symbol = grammar["Start"][0].as_str().unwrap().to_owned();
+    let start_symbol = grammar["Start"][0].as_str().unwrap();
     let mut start_vec = VecDeque::new();
     start_vec.push_back(start_symbol);
     worklist.push_back(Element {
@@ -328,8 +332,7 @@ fn main() {
         );
     }
 
-    state_stacks.q.clear();
-    state_stacks.s.clear();
+    drop(state_stacks);
 
     let transformed = postprocess(&pda, stack_limit);
     let serialized = postcard::to_allocvec(&transformed).unwrap();
