Skip to content

Commit 7aad44b

Browse files
committed
perf(*): cache regex predicates with rc using router attribute
KAG-3182
1 parent aa56399 commit 7aad44b

File tree

4 files changed

+116
-51
lines changed

4 files changed

+116
-51
lines changed

src/ast.rs

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::schema::Schema;
22
use cidr::IpCidr;
33
use regex::Regex;
4-
use std::net::IpAddr;
4+
use std::{net::IpAddr, rc::Rc};
55

66
#[cfg(feature = "serde")]
77
use serde::{Deserialize, Serialize};
@@ -53,7 +53,7 @@ pub enum Value {
5353
IpAddr(IpAddr),
5454
Int(i64),
5555
#[cfg_attr(feature = "serde", serde(with = "serde_regex"))]
56-
Regex(Regex),
56+
Regex(Rc<Regex>),
5757
}
5858

5959
impl PartialEq for Value {
@@ -137,7 +137,7 @@ pub struct Predicate {
137137
mod tests {
138138
use super::*;
139139
use crate::parser::parse;
140-
use std::fmt;
140+
use std::{collections::HashMap, fmt};
141141

142142
impl fmt::Display for Expression {
143143
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@@ -240,6 +240,7 @@ mod tests {
240240

241241
#[test]
242242
fn expr_op_and_prec() {
243+
let mut regex_cache = HashMap::new();
243244
let tests = vec![
244245
("a > 0", "(a > 0)"),
245246
("a in \"abc\"", "(a in \"abc\")"),
@@ -271,13 +272,14 @@ mod tests {
271272
),
272273
];
273274
for (input, expected) in tests {
274-
let result = parse(input).unwrap();
275+
let result = parse(input, &mut regex_cache).unwrap();
275276
assert_eq!(result.to_string(), expected);
276277
}
277278
}
278279

279280
#[test]
280281
fn expr_var_name_and_ip() {
282+
let mut regex_cache = HashMap::new();
281283
let tests = vec![
282284
// ipv4_literal
283285
("kong.foo in 1.1.1.1", "(kong.foo in 1.1.1.1)"),
@@ -298,13 +300,14 @@ mod tests {
298300
),
299301
];
300302
for (input, expected) in tests {
301-
let result = parse(input).unwrap();
303+
let result = parse(input, &mut regex_cache).unwrap();
302304
assert_eq!(result.to_string(), expected);
303305
}
304306
}
305307

306308
#[test]
307309
fn expr_regex() {
310+
let mut regex_cache = HashMap::new();
308311
let tests = vec![
309312
// regex_literal
310313
(
@@ -318,13 +321,14 @@ mod tests {
318321
),
319322
];
320323
for (input, expected) in tests {
321-
let result = parse(input).unwrap();
324+
let result = parse(input, &mut regex_cache).unwrap();
322325
assert_eq!(result.to_string(), expected);
323326
}
324327
}
325328

326329
#[test]
327330
fn expr_digits() {
331+
let mut regex_cache = HashMap::new();
328332
let tests = vec![
329333
// dec literal
330334
("kong.foo.foo7 == 123", "(kong.foo.foo7 == 123)"),
@@ -340,13 +344,14 @@ mod tests {
340344
("kong.foo.foo12 == -0123", "(kong.foo.foo12 == -83)"),
341345
];
342346
for (input, expected) in tests {
343-
let result = parse(input).unwrap();
347+
let result = parse(input, &mut regex_cache).unwrap();
344348
assert_eq!(result.to_string(), expected);
345349
}
346350
}
347351

348352
#[test]
349353
fn expr_transformations() {
354+
let mut regex_cache = HashMap::new();
350355
let tests = vec![
351356
// lower
352357
(
@@ -360,13 +365,14 @@ mod tests {
360365
),
361366
];
362367
for (input, expected) in tests {
363-
let result = parse(input).unwrap();
368+
let result = parse(input, &mut regex_cache).unwrap();
364369
assert_eq!(result.to_string(), expected);
365370
}
366371
}
367372

368373
#[test]
369374
fn expr_transformations_nested() {
375+
let mut regex_cache = HashMap::new();
370376
let tests = vec![
371377
// lower + lower
372378
(
@@ -390,35 +396,37 @@ mod tests {
390396
),
391397
];
392398
for (input, expected) in tests {
393-
let result = parse(input).unwrap();
399+
let result = parse(input, &mut regex_cache).unwrap();
394400
assert_eq!(result.to_string(), expected);
395401
}
396402
}
397403

398404
#[test]
399405
fn str_unicode_test() {
406+
let mut regex_cache = HashMap::new();
400407
let tests = vec![
401408
// cjk chars
402409
("t_msg in \"你好\"", "(t_msg in \"你好\")"),
403410
// 0xXXX unicode
404411
("t_msg in \"\u{4f60}\u{597d}\"", "(t_msg in \"你好\")"),
405412
];
406413
for (input, expected) in tests {
407-
let result = parse(input).unwrap();
414+
let result = parse(input, &mut regex_cache).unwrap();
408415
assert_eq!(result.to_string(), expected);
409416
}
410417
}
411418

412419
#[test]
413420
fn rawstr_test() {
421+
let mut regex_cache = HashMap::new();
414422
let tests = vec![
415423
// invalid escape sequence
416424
(r##"a == r#"/path/to/\d+"#"##, r#"(a == "/path/to/\d+")"#),
417425
// valid escape sequence
418426
(r##"a == r#"/path/to/\n+"#"##, r#"(a == "/path/to/\n+")"#),
419427
];
420428
for (input, expected) in tests {
421-
let result = parse(input).unwrap();
429+
let result = parse(input, &mut regex_cache).unwrap();
422430
assert_eq!(result.to_string(), expected);
423431
}
424432
}

src/parser.rs

Lines changed: 59 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ use pest::pratt_parser::Assoc as AssocNew;
1111
use pest::pratt_parser::{Op, PrattParser};
1212
use pest::Parser;
1313
use regex::Regex;
14+
use std::collections::HashMap;
1415
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
16+
use std::rc::Rc;
1517

1618
type ParseResult<T> = Result<T, ParseError<Rule>>;
1719

@@ -61,12 +63,16 @@ impl ATCParser {
6163
}
6264
// matcher = { SOI ~ expression ~ EOI }
6365
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
64-
fn parse_matcher(&mut self, source: &str) -> ParseResult<Expression> {
66+
fn parse_matcher(
67+
&mut self,
68+
source: &str,
69+
regex_cache: &mut HashMap<String, Rc<Regex>>,
70+
) -> ParseResult<Expression> {
6571
let pairs = ATCParser::parse(Rule::matcher, source)?;
6672
let expr_pair = pairs.peek().unwrap().into_inner().peek().unwrap();
6773
let rule = expr_pair.as_rule();
6874
match rule {
69-
Rule::expression => parse_expression(expr_pair, &self.pratt_parser),
75+
Rule::expression => parse_expression(expr_pair, &self.pratt_parser, regex_cache),
7076
_ => unreachable!(),
7177
}
7278
}
@@ -204,7 +210,10 @@ fn parse_int_literal(pair: Pair<Rule>) -> ParseResult<i64> {
204210

205211
// predicate = { lhs ~ binary_operator ~ rhs }
206212
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
207-
fn parse_predicate(pair: Pair<Rule>) -> ParseResult<Predicate> {
213+
fn parse_predicate(
214+
pair: Pair<Rule>,
215+
regex_cache: &mut HashMap<String, Rc<Regex>>,
216+
) -> ParseResult<Predicate> {
208217
let mut pairs = pair.into_inner();
209218
let lhs = parse_lhs(pairs.next().unwrap())?;
210219
let op = parse_binary_operator(pairs.next().unwrap());
@@ -214,23 +223,22 @@ fn parse_predicate(pair: Pair<Rule>) -> ParseResult<Predicate> {
214223
lhs,
215224
rhs: if op == BinaryOperator::Regex {
216225
if let Value::String(s) = rhs {
217-
let r = Regex::new(&s).map_err(|e| {
218-
ParseError::new_from_span(
219-
ErrorVariant::CustomError {
220-
message: e.to_string(),
221-
},
222-
rhs_pair.as_span(),
223-
)
224-
})?;
225-
226-
Value::Regex(r)
226+
let regex_rc = match regex_cache.get(&s) {
227+
Some(stored_regex_rc) => stored_regex_rc.clone(),
228+
None => {
229+
let r = Regex::new(&s).into_parse_result(&rhs_pair)?;
230+
231+
let rc = Rc::new(r);
232+
233+
regex_cache.insert(s, rc.clone());
234+
rc
235+
}
236+
};
237+
238+
Value::Regex(regex_rc)
227239
} else {
228-
return Err(ParseError::new_from_span(
229-
ErrorVariant::CustomError {
230-
message: "regex operator can only be used with String operands".to_string(),
231-
},
232-
rhs_pair.as_span(),
233-
));
240+
return Err("regex operator can only be used with String operands")
241+
.into_parse_result(&rhs_pair);
234242
}
235243
} else {
236244
rhs
@@ -289,39 +297,53 @@ fn parse_binary_operator(pair: Pair<Rule>) -> BinaryOperator {
289297
fn parse_parenthesised_expression(
290298
pair: Pair<Rule>,
291299
pratt: &PrattParser<Rule>,
300+
regex_cache: &mut HashMap<String, Rc<Regex>>,
292301
) -> ParseResult<Expression> {
293302
let mut pairs = pair.into_inner();
294303
let pair = pairs.next().unwrap();
295304
let rule = pair.as_rule();
296305
match rule {
297-
Rule::expression => parse_expression(pair, pratt),
306+
Rule::expression => parse_expression(pair, pratt, regex_cache),
298307
Rule::not_op => Ok(Expression::Logical(Box::new(LogicalExpression::Not(
299-
parse_expression(pairs.next().unwrap(), pratt)?,
308+
parse_expression(pairs.next().unwrap(), pratt, regex_cache)?,
300309
)))),
301310
_ => unreachable!(),
302311
}
303312
}
304313

305314
// term = { predicate | parenthesised_expression }
306315
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
307-
fn parse_term(pair: Pair<Rule>, pratt: &PrattParser<Rule>) -> ParseResult<Expression> {
316+
fn parse_term(
317+
pair: Pair<Rule>,
318+
pratt: &PrattParser<Rule>,
319+
regex_cache: &mut HashMap<String, Rc<Regex>>,
320+
) -> ParseResult<Expression> {
308321
let pairs = pair.into_inner();
309322
let inner_rule = pairs.peek().unwrap();
310323
let rule = inner_rule.as_rule();
311324
match rule {
312-
Rule::predicate => Ok(Expression::Predicate(parse_predicate(inner_rule)?)),
313-
Rule::parenthesised_expression => parse_parenthesised_expression(inner_rule, pratt),
325+
Rule::predicate => Ok(Expression::Predicate(parse_predicate(
326+
inner_rule,
327+
regex_cache,
328+
)?)),
329+
Rule::parenthesised_expression => {
330+
parse_parenthesised_expression(inner_rule, pratt, regex_cache)
331+
}
314332
_ => unreachable!(),
315333
}
316334
}
317335

318336
// expression = { term ~ ( logical_operator ~ term )* }
319337
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
320-
fn parse_expression(pair: Pair<Rule>, pratt: &PrattParser<Rule>) -> ParseResult<Expression> {
338+
fn parse_expression(
339+
pair: Pair<Rule>,
340+
pratt: &PrattParser<Rule>,
341+
regex_cache: &mut HashMap<String, Rc<Regex>>,
342+
) -> ParseResult<Expression> {
321343
let pairs = pair.into_inner();
322344
pratt
323345
.map_primary(|operand| match operand.as_rule() {
324-
Rule::term => parse_term(operand, pratt),
346+
Rule::term => parse_term(operand, pratt, regex_cache),
325347
_ => unreachable!(),
326348
})
327349
.map_infix(|lhs, op, rhs| {
@@ -335,8 +357,11 @@ fn parse_expression(pair: Pair<Rule>, pratt: &PrattParser<Rule>) -> ParseResult<
335357
}
336358

337359
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
338-
pub fn parse(source: &str) -> ParseResult<Expression> {
339-
ATCParser::new().parse_matcher(source)
360+
pub fn parse(
361+
source: &str,
362+
regex_cache: &mut HashMap<String, Rc<Regex>>,
363+
) -> ParseResult<Expression> {
364+
ATCParser::new().parse_matcher(source, regex_cache)
340365
}
341366

342367
#[cfg(test)]
@@ -345,16 +370,19 @@ mod tests {
345370

346371
#[test]
347372
fn test_bad_syntax() {
373+
let mut regex_cache = HashMap::new();
348374
assert_eq!(
349-
parse("! a == 1").unwrap_err().to_string(),
375+
parse("! a == 1", &mut regex_cache).unwrap_err().to_string(),
350376
" --> 1:1\n |\n1 | ! a == 1\n | ^---\n |\n = expected term"
351377
);
352378
assert_eq!(
353-
parse("a == 1 || ! b == 2").unwrap_err().to_string(),
379+
parse("a == 1 || ! b == 2", &mut regex_cache)
380+
.unwrap_err()
381+
.to_string(),
354382
" --> 1:11\n |\n1 | a == 1 || ! b == 2\n | ^---\n |\n = expected term"
355383
);
356384
assert_eq!(
357-
parse("(a == 1 || b == 2) && ! c == 3")
385+
parse("(a == 1 || b == 2) && ! c == 3", &mut regex_cache)
358386
.unwrap_err()
359387
.to_string(),
360388
" --> 1:23\n |\n1 | (a == 1 || b == 2) && ! c == 3\n | ^---\n |\n = expected term"

0 commit comments

Comments
 (0)