Skip to content

Commit 2a270cf

Browse files
committed
perf(*): cache regex predicates with rc using router attribute
KAG-3182
1 parent d3136e5 commit 2a270cf

File tree

5 files changed

+115
-45
lines changed

5 files changed

+115
-45
lines changed

src/ast.rs

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::schema::Schema;
22
use cidr::IpCidr;
33
use regex::Regex;
4-
use std::net::IpAddr;
4+
use std::{net::IpAddr, rc::Rc};
55

66
#[cfg(feature = "serde")]
77
use serde::{Deserialize, Serialize};
@@ -53,7 +53,7 @@ pub enum Value {
5353
IpAddr(IpAddr),
5454
Int(i64),
5555
#[cfg_attr(feature = "serde", serde(with = "serde_regex"))]
56-
Regex(Regex),
56+
Regex(Rc<Regex>),
5757
}
5858

5959
impl PartialEq for Value {
@@ -174,7 +174,7 @@ pub struct Predicate {
174174
mod tests {
175175
use super::*;
176176
use crate::parser::parse;
177-
use std::fmt;
177+
use std::{collections::HashMap, fmt};
178178

179179
impl fmt::Display for Expression {
180180
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@@ -277,6 +277,7 @@ mod tests {
277277

278278
#[test]
279279
fn expr_op_and_prec() {
280+
let mut regex_cache = HashMap::new();
280281
let tests = vec![
281282
("a > 0", "(a > 0)"),
282283
("a in \"abc\"", "(a in \"abc\")"),
@@ -308,13 +309,14 @@ mod tests {
308309
),
309310
];
310311
for (input, expected) in tests {
311-
let result = parse(input).unwrap();
312+
let result = parse(input, &mut regex_cache).unwrap();
312313
assert_eq!(result.to_string(), expected);
313314
}
314315
}
315316

316317
#[test]
317318
fn expr_var_name_and_ip() {
319+
let mut regex_cache = HashMap::new();
318320
let tests = vec![
319321
// ipv4_literal
320322
("kong.foo in 1.1.1.1", "(kong.foo in 1.1.1.1)"),
@@ -335,13 +337,14 @@ mod tests {
335337
),
336338
];
337339
for (input, expected) in tests {
338-
let result = parse(input).unwrap();
340+
let result = parse(input, &mut regex_cache).unwrap();
339341
assert_eq!(result.to_string(), expected);
340342
}
341343
}
342344

343345
#[test]
344346
fn expr_regex() {
347+
let mut regex_cache = HashMap::new();
345348
let tests = vec![
346349
// regex_literal
347350
(
@@ -355,13 +358,14 @@ mod tests {
355358
),
356359
];
357360
for (input, expected) in tests {
358-
let result = parse(input).unwrap();
361+
let result = parse(input, &mut regex_cache).unwrap();
359362
assert_eq!(result.to_string(), expected);
360363
}
361364
}
362365

363366
#[test]
364367
fn expr_digits() {
368+
let mut regex_cache = HashMap::new();
365369
let tests = vec![
366370
// dec literal
367371
("kong.foo.foo7 == 123", "(kong.foo.foo7 == 123)"),
@@ -377,13 +381,14 @@ mod tests {
377381
("kong.foo.foo12 == -0123", "(kong.foo.foo12 == -83)"),
378382
];
379383
for (input, expected) in tests {
380-
let result = parse(input).unwrap();
384+
let result = parse(input, &mut regex_cache).unwrap();
381385
assert_eq!(result.to_string(), expected);
382386
}
383387
}
384388

385389
#[test]
386390
fn expr_transformations() {
391+
let mut regex_cache = HashMap::new();
387392
let tests = vec![
388393
// lower
389394
(
@@ -397,13 +402,14 @@ mod tests {
397402
),
398403
];
399404
for (input, expected) in tests {
400-
let result = parse(input).unwrap();
405+
let result = parse(input, &mut regex_cache).unwrap();
401406
assert_eq!(result.to_string(), expected);
402407
}
403408
}
404409

405410
#[test]
406411
fn expr_transformations_nested() {
412+
let mut regex_cache = HashMap::new();
407413
let tests = vec![
408414
// lower + lower
409415
(
@@ -427,35 +433,37 @@ mod tests {
427433
),
428434
];
429435
for (input, expected) in tests {
430-
let result = parse(input).unwrap();
436+
let result = parse(input, &mut regex_cache).unwrap();
431437
assert_eq!(result.to_string(), expected);
432438
}
433439
}
434440

435441
#[test]
436442
fn str_unicode_test() {
443+
let mut regex_cache = HashMap::new();
437444
let tests = vec![
438445
// cjk chars
439446
("t_msg in \"你好\"", "(t_msg in \"你好\")"),
440447
// 0xXXX unicode
441448
("t_msg in \"\u{4f60}\u{597d}\"", "(t_msg in \"你好\")"),
442449
];
443450
for (input, expected) in tests {
444-
let result = parse(input).unwrap();
451+
let result = parse(input, &mut regex_cache).unwrap();
445452
assert_eq!(result.to_string(), expected);
446453
}
447454
}
448455

449456
#[test]
450457
fn rawstr_test() {
458+
let mut regex_cache = HashMap::new();
451459
let tests = vec![
452460
// invalid escape sequence
453461
(r##"a == r#"/path/to/\d+"#"##, r#"(a == "/path/to/\d+")"#),
454462
// valid escape sequence
455463
(r##"a == r#"/path/to/\n+"#"##, r#"(a == "/path/to/\n+")"#),
456464
];
457465
for (input, expected) in tests {
458-
let result = parse(input).unwrap();
466+
let result = parse(input, &mut regex_cache).unwrap();
459467
assert_eq!(result.to_string(), expected);
460468
}
461469
}

src/ffi/expression.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::ffi::ERR_BUF_MAX_LEN;
33
use crate::schema::Schema;
44
use bitflags::bitflags;
55
use std::cmp::min;
6+
use std::collections::HashMap;
67
use std::ffi;
78
use std::os::raw::c_char;
89
use std::slice::from_raw_parts_mut;
@@ -163,7 +164,7 @@ pub unsafe extern "C" fn expression_validate(
163164
let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN);
164165

165166
// Parse the expression
166-
let result = parse(atc).map_err(|e| e.to_string());
167+
let result = parse(atc, &mut HashMap::new()).map_err(|e| e.to_string());
167168
if let Err(e) = result {
168169
let errlen = min(e.len(), *errbuf_len);
169170
errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]);

src/parser.rs

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ use pest::pratt_parser::Assoc as AssocNew;
1111
use pest::pratt_parser::{Op, PrattParser};
1212
use pest::Parser;
1313
use regex::Regex;
14+
use std::collections::HashMap;
1415
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
16+
use std::rc::Rc;
1517

1618
type ParseResult<T> = Result<T, ParseError<Rule>>;
1719

@@ -61,12 +63,16 @@ impl ATCParser {
6163
}
6264
// matcher = { SOI ~ expression ~ EOI }
6365
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
64-
fn parse_matcher(&mut self, source: &str) -> ParseResult<Expression> {
66+
fn parse_matcher(
67+
&mut self,
68+
source: &str,
69+
regex_cache: &mut HashMap<String, Rc<Regex>>,
70+
) -> ParseResult<Expression> {
6571
let pairs = ATCParser::parse(Rule::matcher, source)?;
6672
let expr_pair = pairs.peek().unwrap().into_inner().peek().unwrap();
6773
let rule = expr_pair.as_rule();
6874
match rule {
69-
Rule::expression => parse_expression(expr_pair, &self.pratt_parser),
75+
Rule::expression => parse_expression(expr_pair, &self.pratt_parser, regex_cache),
7076
_ => unreachable!(),
7177
}
7278
}
@@ -204,7 +210,10 @@ fn parse_int_literal(pair: Pair<Rule>) -> ParseResult<i64> {
204210

205211
// predicate = { lhs ~ binary_operator ~ rhs }
206212
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
207-
fn parse_predicate(pair: Pair<Rule>) -> ParseResult<Predicate> {
213+
fn parse_predicate(
214+
pair: Pair<Rule>,
215+
regex_cache: &mut HashMap<String, Rc<Regex>>,
216+
) -> ParseResult<Predicate> {
208217
let mut pairs = pair.into_inner();
209218
let lhs = parse_lhs(pairs.next().unwrap())?;
210219
let op = parse_binary_operator(pairs.next().unwrap());
@@ -222,16 +231,19 @@ fn parse_predicate(pair: Pair<Rule>) -> ParseResult<Predicate> {
222231
));
223232
};
224233

225-
let r = Regex::new(&s).map_err(|e| {
226-
ParseError::new_from_span(
227-
ErrorVariant::CustomError {
228-
message: e.to_string(),
229-
},
230-
rhs_pair.as_span(),
231-
)
232-
})?;
234+
let regex_rc = match regex_cache.get(&s) {
235+
Some(stored_regex_rc) => stored_regex_rc.clone(),
236+
None => {
237+
let r = Regex::new(&s).into_parse_result(&rhs_pair)?;
238+
239+
let rc = Rc::new(r);
233240

234-
Value::Regex(r)
241+
regex_cache.insert(s, rc.clone());
242+
rc
243+
}
244+
};
245+
246+
Value::Regex(regex_rc)
235247
} else {
236248
rhs
237249
},
@@ -290,39 +302,53 @@ fn parse_binary_operator(pair: Pair<Rule>) -> BinaryOperator {
290302
fn parse_parenthesised_expression(
291303
pair: Pair<Rule>,
292304
pratt: &PrattParser<Rule>,
305+
regex_cache: &mut HashMap<String, Rc<Regex>>,
293306
) -> ParseResult<Expression> {
294307
let mut pairs = pair.into_inner();
295308
let pair = pairs.next().unwrap();
296309
let rule = pair.as_rule();
297310
match rule {
298-
Rule::expression => parse_expression(pair, pratt),
311+
Rule::expression => parse_expression(pair, pratt, regex_cache),
299312
Rule::not_op => Ok(Expression::Logical(Box::new(LogicalExpression::Not(
300-
parse_expression(pairs.next().unwrap(), pratt)?,
313+
parse_expression(pairs.next().unwrap(), pratt, regex_cache)?,
301314
)))),
302315
_ => unreachable!(),
303316
}
304317
}
305318

306319
// term = { predicate | parenthesised_expression }
307320
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
308-
fn parse_term(pair: Pair<Rule>, pratt: &PrattParser<Rule>) -> ParseResult<Expression> {
321+
fn parse_term(
322+
pair: Pair<Rule>,
323+
pratt: &PrattParser<Rule>,
324+
regex_cache: &mut HashMap<String, Rc<Regex>>,
325+
) -> ParseResult<Expression> {
309326
let pairs = pair.into_inner();
310327
let inner_rule = pairs.peek().unwrap();
311328
let rule = inner_rule.as_rule();
312329
match rule {
313-
Rule::predicate => Ok(Expression::Predicate(parse_predicate(inner_rule)?)),
314-
Rule::parenthesised_expression => parse_parenthesised_expression(inner_rule, pratt),
330+
Rule::predicate => Ok(Expression::Predicate(parse_predicate(
331+
inner_rule,
332+
regex_cache,
333+
)?)),
334+
Rule::parenthesised_expression => {
335+
parse_parenthesised_expression(inner_rule, pratt, regex_cache)
336+
}
315337
_ => unreachable!(),
316338
}
317339
}
318340

319341
// expression = { term ~ ( logical_operator ~ term )* }
320342
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
321-
fn parse_expression(pair: Pair<Rule>, pratt: &PrattParser<Rule>) -> ParseResult<Expression> {
343+
fn parse_expression(
344+
pair: Pair<Rule>,
345+
pratt: &PrattParser<Rule>,
346+
regex_cache: &mut HashMap<String, Rc<Regex>>,
347+
) -> ParseResult<Expression> {
322348
let pairs = pair.into_inner();
323349
pratt
324350
.map_primary(|operand| match operand.as_rule() {
325-
Rule::term => parse_term(operand, pratt),
351+
Rule::term => parse_term(operand, pratt, regex_cache),
326352
_ => unreachable!(),
327353
})
328354
.map_infix(|lhs, op, rhs| {
@@ -336,8 +362,11 @@ fn parse_expression(pair: Pair<Rule>, pratt: &PrattParser<Rule>) -> ParseResult<
336362
}
337363

338364
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
339-
pub fn parse(source: &str) -> ParseResult<Expression> {
340-
ATCParser::new().parse_matcher(source)
365+
pub fn parse(
366+
source: &str,
367+
regex_cache: &mut HashMap<String, Rc<Regex>>,
368+
) -> ParseResult<Expression> {
369+
ATCParser::new().parse_matcher(source, regex_cache)
341370
}
342371

343372
#[cfg(test)]
@@ -346,16 +375,19 @@ mod tests {
346375

347376
#[test]
348377
fn test_bad_syntax() {
378+
let mut regex_cache = HashMap::new();
349379
assert_eq!(
350-
parse("! a == 1").unwrap_err().to_string(),
380+
parse("! a == 1", &mut regex_cache).unwrap_err().to_string(),
351381
" --> 1:1\n |\n1 | ! a == 1\n | ^---\n |\n = expected term"
352382
);
353383
assert_eq!(
354-
parse("a == 1 || ! b == 2").unwrap_err().to_string(),
384+
parse("a == 1 || ! b == 2", &mut regex_cache)
385+
.unwrap_err()
386+
.to_string(),
355387
" --> 1:11\n |\n1 | a == 1 || ! b == 2\n | ^---\n |\n = expected term"
356388
);
357389
assert_eq!(
358-
parse("(a == 1 || b == 2) && ! c == 3")
390+
parse("(a == 1 || b == 2) && ! c == 3", &mut regex_cache)
359391
.unwrap_err()
360392
.to_string(),
361393
" --> 1:23\n |\n1 | (a == 1 || b == 2) && ! c == 3\n | ^---\n |\n = expected term"

0 commit comments

Comments
 (0)