diff --git a/Cargo.lock b/Cargo.lock index fb34bf1f..3986b55e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,9 +19,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" [[package]] name = "atc-router" @@ -35,6 +35,7 @@ dependencies = [ "pest_derive", "regex", "serde", + "serde_json", "serde_regex", "uuid", ] @@ -251,6 +252,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "half" version = "2.4.1" @@ -310,9 +322,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.159" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "log" @@ -422,9 +434,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.87" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] @@ -460,9 +472,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -504,18 +516,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.210" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", @@ -524,9 +536,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", "memchr", @@ -557,9 +569,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.79" +version = "2.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" dependencies = [ "proc-macro2", "quote", @@ -568,18 +580,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.64" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" +checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" dependencies = [ "thiserror-impl", ] 
[[package]] name = "thiserror-impl" -version = "1.0.64" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" +checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" dependencies = [ "proc-macro2", "quote", @@ -616,9 +628,12 @@ checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" +dependencies = [ + "getrandom", +] [[package]] name = "version_check" @@ -636,6 +651,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + [[package]] name = "wasm-bindgen" version = "0.2.95" diff --git a/Cargo.toml b/Cargo.toml index 5782964c..f7fef601 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,13 @@ fnv = "1" [dev-dependencies] criterion = "0.*" +serde_json = "1" +serde = "1" +uuid = {version = "1.8", features = ["v4"]} + +[[bench]] +name = "misc_match" +harness = false [lib] crate-type = ["lib", "cdylib", "staticlib"] diff --git a/README.md b/README.md index cc441902..89105f80 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ ATC Router library for Kong. * [resty.router.context](#restyroutercontext) * [new](#new) * [add\_value](#add_value) + * [add\_value\_by\_index](#add_value_by_index) * [get\_result](#get_result) * [reset](#reset) * [Copyright and license](#copyright-and-license) @@ -186,9 +187,9 @@ none of the matcher matched. 
**context:** *any* -Returns the currently used field names by all matchers inside the router as -an Lua array. It can help reduce unnecessarily producing values that are not -actually used by the user supplied matchers. +Returns the currently used {field name: field index} map by all matchers inside +the router as a Lua table. It can help reduce unnecessarily producing values +that are not actually used by the user supplied matchers. [Back to TOC](#table-of-contents) @@ -209,12 +210,12 @@ Returns the fields used in the provided expression when the expression is valid. ### new -**syntax:** *c = context.new(schema)* +**syntax:** *c = context.new(router)* **context:** *any* Create a new context instance that can later be used for storing contextual information. -for router matches. `schema` must refer to an existing schema instance. +for router matches. `router` must refer to an existing router instance. [Back to TOC](#table-of-contents) @@ -232,6 +233,28 @@ If an error occurred, `nil` and a string describing the error will be returned. [Back to TOC](#table-of-contents) +### add\_value\_by\_index + +**syntax:** *res, err = c:add_value_by_index(field, value, index)* + +**context:** *any* + +Provides `value` for `field` inside the context. + +Use the `index` obtained from `r:get_fields()` to provide `value` for `field`. + +This method is faster than `add_value`, but note that users must guarantee + +the index obtained from `r:get_fields()` is not stale. For example, `router` fields changed + +by `remove_matcher` may invalidate previously obtained field indexes. + +Returns `true` if field exists and value has successfully been provided. + +If an error occurred, `nil` and a string describing the error will be returned. 
+ +[Back to TOC](#table-of-contents) + ### get\_result **syntax:** *uuid, matched_value, captures = c:get_result(matched_field)* diff --git a/benches/build.rs b/benches/build.rs index b42fc8c4..229c39e4 100644 --- a/benches/build.rs +++ b/benches/build.rs @@ -20,9 +20,6 @@ fn criterion_benchmark(c: &mut Criterion) { let mut schema = Schema::default(); schema.add_field("a", Type::Int); - let mut context = Context::new(&schema); - context.add_value("a", Value::Int(N as i64)); - c.bench_function("Build Router", |b| { b.iter_with_large_drop(|| { let mut router = Router::new(&schema); diff --git a/benches/data.json b/benches/data.json new file mode 100644 index 00000000..2d43d68e --- /dev/null +++ b/benches/data.json @@ -0,0 +1,55 @@ +{ + "rules": [ + "net.protocol == \"http\"", + "tls.sni == \"server1\"", + "http.method == \"GET\"", + "http.host == \"example.com\"", + "http.path == \"/foo\"", + "http.path.segments.1 == \"bar\"", + "http.path.segments.0_1 == \"foo/bar\"", + "http.path.segments.len == 2", + "http.headers.foo_bar == \"whatever\"", + "net.dst.port == 8443", + "net.src.ip == 192.168.1.1", + "net.src.ip in 192.168.1.0/24" + ], + "match_keys": [ + "net.protocol", + "tls.sni", + "http.method", + "http.host", + "http.path", + "http.path.segments.1", + "http.path.segments.0_1", + "http.path.segments.len", + "http.headers.foo_bar", + "net.dst.port", + "net.src.ip" + ], + "match_values": [ + "http", + "server1", + "GET", + "example.com", + "/foo", + "bar\"", + "foo/bar\"", + 2, + "whatever", + 8443, + ["192.168.1.1"] + ], + "not_match_values": [ + "https", + "server2", + "POST", + "example_foo.com", + "/fooo", + "/barr\"", + "/fooo/bar\"", + 3, + "whatever_wrong", + 18443, + ["192.168.2.1"] + ] +} \ No newline at end of file diff --git a/benches/match_mix.rs b/benches/match_mix.rs index 056c7f69..d7a6a535 100644 --- a/benches/match_mix.rs +++ b/benches/match_mix.rs @@ -34,7 +34,7 @@ fn criterion_benchmark(c: &mut Criterion) { router.add_matcher(N - i, uuid, 
&expr).unwrap(); } - let mut ctx_match = Context::new(&schema); + let mut ctx_match = Context::new(&router); ctx_match.add_value( "http.path", atc_router::ast::Value::String("hello49999".to_string()), diff --git a/benches/misc_match.rs b/benches/misc_match.rs new file mode 100644 index 00000000..fdcff753 --- /dev/null +++ b/benches/misc_match.rs @@ -0,0 +1,136 @@ +use atc_router::{ + ast::Type, + ast::Value, + context::Context, + router::Router, + schema::{self, Schema}, +}; +use criterion::{criterion_group, criterion_main, Criterion}; +use serde::{Deserialize, Serialize}; +use serde_json; +use std::env; +use std::fs; +use std::net::{IpAddr, Ipv4Addr}; +use std::{hint::black_box, str::FromStr}; +use uuid::Uuid; + +// To run this benchmark, execute the following command: +// ```shell +// cargo bench --bench misc_match +// ``` + +#[derive(Serialize, Deserialize)] +struct TestData { + rules: Vec, + match_keys: Vec, + match_values: Vec, + not_match_values: Vec, +} + +// prepare match rules, context keys, context values from data.json file +fn prepare_data() -> TestData { + let cwd = env::current_dir().unwrap(); + let file_str = + fs::read_to_string(cwd.join("benches/data.json")).expect("unable to open data.json"); + serde_json::from_str(&file_str).unwrap() +} + +// setup Schema +fn setup_schema() -> schema::Schema { + let mut s = Schema::default(); + s.add_field("net.protocol", Type::String); + s.add_field("tls.sni", Type::String); + s.add_field("http.method", Type::String); + s.add_field("http.host", Type::String); + s.add_field("http.path", Type::String); + s.add_field("http.path.segments.*", Type::String); + s.add_field("http.path.segments.len", Type::Int); + s.add_field("http.headers.*", Type::String); + s.add_field("net.dst.port", Type::Int); + s.add_field("net.src.ip", Type::IpAddr); + s +} + +// setup matchers, which be added from priority 100 with descending order +fn setup_matchers(r: &mut Router, data: &TestData) { + let mut pri = 100; + for v in &data.rules 
{ + let id = Uuid::new_v4(); + let _ = r.add_matcher(pri, id, v.as_str().unwrap()); + pri -= 1; + } +} + +// mock contexts with field values passed in from json data +fn setup_context(ctx: &mut Context, data: &TestData, test_match: bool) { + let values = if test_match { + &data.match_values + } else { + &data.not_match_values + }; + for (i, v) in values.iter().enumerate() { + match v { + serde_json::Value::String(s) => { + ctx.add_value_by_index(i, Value::String(s.to_string())); + } + serde_json::Value::Number(n) => { + ctx.add_value_by_index(i, Value::Int(n.as_i64().unwrap())); + } + serde_json::Value::Array(l) => { + ctx.add_value_by_index( + i, + Value::IpAddr(IpAddr::V4( + Ipv4Addr::from_str(l[0].as_str().unwrap()).unwrap(), + )), + ); + } + _ => panic!("incorrect data type"), + } + } +} + +fn router_match(router: &Router, ctx: &mut Context, expected: bool) { + assert_eq!(router.execute(ctx), expected); +} + +fn matchers_batch_handling(s: &Schema) { + let mut r = Router::new(&s); + let pri_max = 10000; + let mut ids = vec![]; + for pri in 0..pri_max { + let id: Uuid = Uuid::new_v4(); + let exp = format!(r#"http.path.segments.{} == "/bar""#, pri.to_string()); + assert!(r.add_matcher(pri, id, exp.as_str()).is_ok()); + ids.push((pri, id)); + } + + for (pri, id) in ids { + assert!(r.remove_matcher(pri, id)); + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let data = prepare_data(); + let s = setup_schema(); + let mut r = Router::new(&s); + setup_matchers(&mut r, &data); + + let mut ctx = Context::new(&r); + setup_context(&mut ctx, &data, true); + c.bench_function("route match all", |b| { + b.iter(|| router_match(black_box(&r), black_box(&mut ctx), black_box(true))) + }); + + let mut ctx = Context::new(&r); + setup_context(&mut ctx, &data, false); + c.bench_function("route mismatch all", |b| { + b.iter(|| router_match(black_box(&r), black_box(&mut ctx), black_box(false))) + }); + + c.bench_function("route matchers batch create and delete", |b| { + b.iter(|| 
matchers_batch_handling(black_box(&s))) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/benches/not_match_mix.rs b/benches/not_match_mix.rs index bd93432a..53a07527 100644 --- a/benches/not_match_mix.rs +++ b/benches/not_match_mix.rs @@ -34,7 +34,7 @@ fn criterion_benchmark(c: &mut Criterion) { router.add_matcher(N - i, uuid, &expr).unwrap(); } - let mut ctx_match = Context::new(&schema); + let mut ctx_match = Context::new(&router); ctx_match.add_value( "http.path", atc_router::ast::Value::String("hello49999".to_string()), diff --git a/benches/string.rs b/benches/string.rs index 3516ff38..542d8d5a 100644 --- a/benches/string.rs +++ b/benches/string.rs @@ -34,7 +34,7 @@ fn criterion_benchmark(c: &mut Criterion) { router.add_matcher(N_MATCHER - i, uuid, &expr).unwrap(); } - let mut context = Context::new(&schema); + let mut context = Context::new(&router); context.add_value("http.path.segments.0_1", "test/run".to_string().into()); context.add_value("http.path.segments.3", "bar".to_string().into()); context.add_value("http.path.segments.len", Value::Int(3 as i64)); diff --git a/benches/test.rs b/benches/test.rs index e361af73..adcc8ac3 100644 --- a/benches/test.rs +++ b/benches/test.rs @@ -29,7 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) { router.add_matcher(N - i, uuid, &expr).unwrap(); } - let mut context = Context::new(&schema); + let mut context = Context::new(&router); context.add_value("a", Value::Int(N as i64)); c.bench_function("Doesn't Match", |b| { diff --git a/lib/resty/router/cdefs.lua b/lib/resty/router/cdefs.lua index 0d0349f0..5d0824fb 100644 --- a/lib/resty/router/cdefs.lua +++ b/lib/resty/router/cdefs.lua @@ -68,9 +68,10 @@ bool router_execute(const struct Router *router, struct Context *context); uintptr_t router_get_fields(const struct Router *router, const uint8_t **fields, - uintptr_t *fields_len); + uintptr_t *fields_len, + uintptr_t *indexes); -struct Context *context_new(const struct 
Schema *schema); +struct Context *context_new(const struct Router *router); void context_free(struct Context *context); @@ -80,6 +81,12 @@ bool context_add_value(struct Context *context, uint8_t *errbuf, uintptr_t *errbuf_len); +bool context_add_value_by_index(struct Context *context, + uintptr_t index, + const struct CValue *value, + uint8_t *errbuf, + uintptr_t *errbuf_len); + void context_reset(struct Context *context); intptr_t context_get_result(const struct Context *context, diff --git a/lib/resty/router/context.lua b/lib/resty/router/context.lua index c88d6552..bd709894 100644 --- a/lib/resty/router/context.lua +++ b/lib/resty/router/context.lua @@ -26,23 +26,23 @@ local clib = cdefs.clib local context_free = cdefs.context_free -function _M.new(schema) - local context = clib.context_new(schema.schema) +function _M.new(router) + local context = clib.context_new(router.router) local c = setmetatable({ context = ffi_gc(context, context_free), - schema = schema, + schema = router.schema, }, _MT) return c end -function _M:add_value(field, value) +local function add_value_impl(ctx, field, value, index) if not value then return true end - local typ, err = self.schema:get_field_type(field) + local typ, err = ctx.schema:get_field_type(field) if not typ then return nil, err end @@ -63,9 +63,14 @@ function _M:add_value(field, value) local errbuf = get_string_buf(ERR_BUF_MAX_LEN) local errbuf_len = get_size_ptr() - errbuf_len[0] = ERR_BUF_MAX_LEN - - if clib.context_add_value(self.context, field, CACHED_VALUE, errbuf, errbuf_len) == false then + errbuf_len[0] = ERR_BUF_MAX_LEN + local res + if index ~= nil then + res = clib.context_add_value_by_index(ctx.context, index, CACHED_VALUE, errbuf, errbuf_len) + else + res = clib.context_add_value(ctx.context, field, CACHED_VALUE, errbuf, errbuf_len) + end + if res == false then return nil, ffi_string(errbuf, errbuf_len[0]) end @@ -73,6 +78,16 @@ function _M:add_value(field, value) end +function _M:add_value(field, value) + 
return add_value_impl(self, field, value) +end + + +function _M:add_value_by_index(field, value, index) + return add_value_impl(self, field, value, index) +end + + function _M:get_result(matched_field) local captures_len = tonumber(clib.context_get_result( self.context, nil, nil, nil, nil, nil, nil, nil, nil)) diff --git a/lib/resty/router/router.lua b/lib/resty/router/router.lua index a914a2db..536c8e82 100644 --- a/lib/resty/router/router.lua +++ b/lib/resty/router/router.lua @@ -75,23 +75,21 @@ end function _M:get_fields() local out = {} - local out_n = 0 local router = self.router - local total = tonumber(clib.router_get_fields(router, nil, nil)) + local total = tonumber(clib.router_get_fields(router, nil, nil, nil)) if total == 0 then return out end local fields = ffi_new("const uint8_t *[?]", total) local fields_len = ffi_new("size_t [?]", total) + local indexes = ffi_new("size_t [?]", total) fields_len[0] = total - clib.router_get_fields(router, fields, fields_len) - + clib.router_get_fields(router, fields, fields_len, indexes) for i = 0, total - 1 do - out_n = out_n + 1 - out[out_n] = ffi_string(fields[i], fields_len[i]) + out[ffi_string(fields[i], fields_len[i])] = indexes[i] end return out diff --git a/src/ast.rs b/src/ast.rs index 42a04752..a2cd68b2 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -104,6 +104,7 @@ pub enum Type { #[derive(Debug, Clone)] pub struct Lhs { pub var_name: String, + pub index: usize, pub transformations: Vec, } diff --git a/src/cir.rs b/src/cir.rs index da1fecbb..f77e19e6 100644 --- a/src/cir.rs +++ b/src/cir.rs @@ -2,8 +2,8 @@ use crate::ast::{Expression, LogicalExpression, Predicate}; use crate::context::{Context, Match}; use crate::interpreter::Execute; +use crate::router::Fields; use crate::semantics::FieldCounter; -use std::collections::HashMap; #[derive(Debug)] pub struct CirProgram { @@ -178,61 +178,103 @@ impl Execute for CirProgram { } impl FieldCounter for CirOperand { - fn add_to_counter(&self, map: &mut HashMap) { - if 
let CirOperand::Predicate(p) = &self { - *map.entry(p.lhs.var_name.clone()).or_default() += 1 + fn add_to_counter(&mut self, fields: &mut Fields) { + if let CirOperand::Predicate(p) = self { + // 1. fields: increment counter for field + // 2. lhs: assign field index to the LHS + // 3. map: maintain the fields map: {field_name : field_index} + if let Some(index) = fields.map.get(&p.lhs.var_name) { + fields.list[*index].as_mut().unwrap().1 += 1; + p.lhs.index = *index; + } else { + // reuse slots in queue if possible + let new_idx: usize; + if fields.slots.is_empty() { + fields.list.push(Some((p.lhs.var_name.clone(), 1))); + new_idx = fields.list.len() - 1; + } else { + new_idx = fields.slots.pop().unwrap(); + fields.list[new_idx] = Some((p.lhs.var_name.clone(), 1)); + } + fields.map.insert(p.lhs.var_name.clone(), new_idx); + p.lhs.index = new_idx; + } } } - fn remove_from_counter(&self, map: &mut HashMap) { + fn remove_from_counter(&mut self, fields: &mut Fields) { if let CirOperand::Predicate(p) = &self { - let val = map.get_mut(&p.lhs.var_name).unwrap(); - *val -= 1; - - if *val == 0 { - assert!(map.remove(&p.lhs.var_name).is_some()); + let index: usize = p.lhs.index; + // decrement counter of field + fields.list[index].as_mut().unwrap().1 -= 1; + // for field removing, reserve the slot for resue and remove it in map + if fields.list[index].as_mut().unwrap().1 == 0 { + fields.list[index] = None; + fields.slots.push(index); + assert!(fields.map.remove(&p.lhs.var_name).is_some()); } } } } impl FieldCounter for CirInstruction { - fn add_to_counter(&self, map: &mut HashMap) { + fn add_to_counter(&mut self, fields: &mut Fields) { match self { CirInstruction::AndIns(and) => { - and.left.add_to_counter(map); - and.right.add_to_counter(map); + and.left.add_to_counter(fields); + and.right.add_to_counter(fields); } CirInstruction::OrIns(or) => { - or.left.add_to_counter(map); - or.right.add_to_counter(map); + or.left.add_to_counter(fields); + 
or.right.add_to_counter(fields); } CirInstruction::NotIns(not) => { - not.right.add_to_counter(map); + not.right.add_to_counter(fields); } CirInstruction::Predicate(p) => { - *map.entry(p.lhs.var_name.clone()).or_default() += 1; + // 1. fields: increment counter for field + // 2. lhs: assign field index to the LHS + // 3. map: maintain the fields map: {field_name : field_index} + if let Some(index) = fields.map.get(&p.lhs.var_name) { + fields.list[*index].as_mut().unwrap().1 += 1; + p.lhs.index = *index; + } else { + // reuse slots in queue if possible + let new_idx: usize; + if fields.slots.is_empty() { + fields.list.push(Some((p.lhs.var_name.clone(), 1))); + new_idx = fields.list.len() - 1; + } else { + new_idx = fields.slots.pop().unwrap(); + fields.list[new_idx] = Some((p.lhs.var_name.clone(), 1)); + } + fields.map.insert(p.lhs.var_name.clone(), new_idx); + p.lhs.index = new_idx; + } } } } - fn remove_from_counter(&self, map: &mut HashMap) { + fn remove_from_counter(&mut self, fields: &mut Fields) { match self { CirInstruction::AndIns(and) => { - and.left.remove_from_counter(map); - and.right.remove_from_counter(map); + and.left.remove_from_counter(fields); + and.right.remove_from_counter(fields); } CirInstruction::OrIns(or) => { - or.left.remove_from_counter(map); - or.right.remove_from_counter(map); + or.left.remove_from_counter(fields); + or.right.remove_from_counter(fields); } CirInstruction::NotIns(not) => { - not.right.remove_from_counter(map); + not.right.remove_from_counter(fields); } CirInstruction::Predicate(p) => { - let val = map.get_mut(&p.lhs.var_name).unwrap(); - *val -= 1; - - if *val == 0 { - assert!(map.remove(&p.lhs.var_name).is_some()); + let index: usize = p.lhs.index; + // decrement counter of field + fields.list[index].as_mut().unwrap().1 -= 1; + // for field removing, reserve the slot for resue and remove it in map + if fields.list[index].as_mut().unwrap().1 == 0 { + fields.list[index] = None; + fields.slots.push(index); + 
assert!(fields.map.remove(&p.lhs.var_name).is_some()); } } } @@ -240,19 +282,51 @@ impl FieldCounter for CirInstruction { } impl FieldCounter for CirProgram { - fn add_to_counter(&self, map: &mut HashMap) { + fn add_to_counter(&mut self, fields: &mut Fields) { self.instructions - .iter() - .for_each(|instruction: &CirInstruction| instruction.add_to_counter(map)); + .iter_mut() + .for_each(|instruction: &mut CirInstruction| instruction.add_to_counter(fields)); } - fn remove_from_counter(&self, map: &mut HashMap) { + fn remove_from_counter(&mut self, fields: &mut Fields) { self.instructions - .iter() - .for_each(|instruction: &CirInstruction| instruction.remove_from_counter(map)); + .iter_mut() + .for_each(|instruction: &mut CirInstruction| instruction.remove_from_counter(fields)); } } +#[cfg(test)] +pub fn get_predicates(cir: &CirProgram) -> Vec<&Predicate> { + let mut predicates = Vec::new(); + cir.instructions.iter().for_each(|ins| match ins { + CirInstruction::AndIns(and) => { + if let CirOperand::Predicate(predicate) = &and.left { + predicates.push(predicate); + } + if let CirOperand::Predicate(predicate) = &and.right { + predicates.push(predicate); + } + } + CirInstruction::OrIns(or) => { + if let CirOperand::Predicate(predicate) = &or.left { + predicates.push(predicate); + } + if let CirOperand::Predicate(predicate) = &or.right { + predicates.push(predicate); + } + } + CirInstruction::NotIns(not) => { + if let CirOperand::Predicate(predicate) = ¬.right { + predicates.push(predicate); + } + } + CirInstruction::Predicate(predicate) => { + predicates.push(predicate); + } + }); + predicates +} + #[cfg(test)] mod tests { use super::*; @@ -260,6 +334,7 @@ mod tests { use crate::ast::Value; use crate::context::Match; use crate::interpreter::Execute; + use crate::router::Router; use crate::schema::Schema; impl Execute for Expression { @@ -292,7 +367,8 @@ mod tests { r#"http.path == "hello" && http.version == "1.1""#, ]; - let mut context = 
crate::context::Context::new(&schema); + let r = Router::new(&schema); + let mut context = crate::context::Context::new(&r); context.add_value("http.path", crate::ast::Value::String("hello".to_string())); context.add_value("http.version", crate::ast::Value::String("1.1".to_string())); context.add_value("a", Value::Int(3 as i64)); diff --git a/src/context.rs b/src/context.rs index 3d27305f..e13e8e75 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,5 +1,4 @@ -use crate::ast::Value; -use crate::schema::Schema; +use crate::{ast::Value, router::Router}; use fnv::FnvHashMap; use uuid::Uuid; @@ -26,37 +25,166 @@ impl Default for Match { } pub struct Context<'a> { - schema: &'a Schema, - values: FnvHashMap>, + router: &'a Router<'a>, + values: Vec>>, pub result: Option, } impl<'a> Context<'a> { - pub fn new(schema: &'a Schema) -> Self { + pub fn new(router: &'a Router) -> Self { Context { - schema, - values: FnvHashMap::with_hasher(Default::default()), + router, + values: vec![None; router.fields.list.len()], result: None, } } pub fn add_value(&mut self, field: &str, value: Value) { - if &value.my_type() != self.schema.type_of(field).unwrap() { + if &value.my_type() != self.router.schema().type_of(field).unwrap() { panic!("value provided does not match schema"); } + if let Some(index) = self.router.fields.map.get(field) { + if let Some(v) = &mut self.values[*index] { + v.push(value); + } else { + self.values[*index] = Some(vec![value]); + } + } + } + + pub fn add_value_by_index(&mut self, index: usize, value: Value) { + if index >= self.values.len() { + panic!( + "value provided does not match schema: index {}, max fields count {}", + index, + self.values.len() + ); + } - self.values - .entry(field.to_string()) - .or_default() - .push(value); + if let Some(v) = &mut self.values[index] { + v.push(value); + } else { + self.values[index] = Some(vec![value]); + } } - pub fn value_of(&self, field: &str) -> Option<&[Value]> { - self.values.get(field).map(|v| v.as_slice()) 
+ pub fn value_of(&self, index: usize) -> Option<&[Value]> { + if !self.values.is_empty() && self.values[index].is_some() { + Some(self.values[index].as_ref().unwrap().as_slice()) + } else { + None + } } pub fn reset(&mut self) { + let len = self.values.len(); + // reserve the capacity of values for reuse, avoid re-alloc self.values.clear(); + self.values.resize_with(len, Default::default); self.result = None; } } + +#[cfg(test)] +mod tests { + use crate::ast::{Type, Value}; + use crate::context::Context; + use crate::router::Router; + use crate::schema::Schema; + use uuid::Uuid; + + fn setup_matcher(r: &mut Router) -> usize { + let fields_cnt = 3; + for i in 0..fields_cnt { + let id: Uuid = Uuid::new_v4(); + let exp = format!(r#"http.path.segments.{} == "/bar""#, i.to_string()); + let pri = i; + assert!(r.add_matcher(pri, id, exp.as_str()).is_ok()); + } + fields_cnt + } + + #[test] + fn test_context() { + let mut s = Schema::default(); + s.add_field("http.path.segments.*", Type::String); + let mut r = Router::new(&s); + let fields_cnt = setup_matcher(&mut r); + + let mut ctx = Context::new(&r); + assert!(ctx.values.len() == fields_cnt); + assert_eq!(ctx.values, vec![None; fields_cnt]); + // access value with out of bound index + assert_eq!(ctx.value_of(0), None); + + // add value in bound + ctx.add_value("http.path.segments.1", Value::String("foo".to_string())); + assert_eq!(ctx.value_of(0), None); + assert_eq!(ctx.value_of(1).unwrap().len(), 1); + assert_eq!( + ctx.value_of(1).unwrap(), + vec![Value::String("foo".to_string())].as_slice() + ); + + // reset context keeps values capacity with all None + ctx.reset(); + assert!(ctx.values.len() == fields_cnt); + assert_eq!(ctx.values, vec![None; fields_cnt]); + + // reuse this context + ctx.add_value("http.path.segments.0", Value::String("bar".to_string())); + ctx.add_value("http.path.segments.0", Value::String("foo".to_string())); + assert!(ctx.values.len() == fields_cnt); + assert_eq!(ctx.value_of(0).unwrap().len(), 
2); + assert_eq!( + ctx.value_of(0).unwrap(), + vec![ + Value::String("bar".to_string()), + Value::String("foo".to_string()) + ] + .as_slice() + ); + } + + #[test] + fn test_context_by_index() { + let mut s = Schema::default(); + s.add_field("http.path.segments.*", Type::String); + let mut r = Router::new(&s); + let fields_cnt = setup_matcher(&mut r); + + let mut ctx = Context::new(&r); + assert!(ctx.values.len() == fields_cnt); + assert_eq!(ctx.values, vec![None; fields_cnt]); + // access value with out of bound index + assert_eq!(ctx.value_of(0), None); + + // add value in bound + ctx.add_value_by_index(1, Value::String("foo".to_string())); + assert_eq!(ctx.value_of(0), None); + assert_eq!(ctx.value_of(1).unwrap().len(), 1); + assert_eq!( + ctx.value_of(1).unwrap(), + vec![Value::String("foo".to_string())].as_slice() + ); + + // reset context keeps values capacity with all None + ctx.reset(); + assert!(ctx.values.len() == fields_cnt); + assert_eq!(ctx.values, vec![None; fields_cnt]); + + // reuse this context + ctx.add_value_by_index(0, Value::String("bar".to_string())); + ctx.add_value_by_index(0, Value::String("foo".to_string())); + assert!(ctx.values.len() == fields_cnt); + assert_eq!(ctx.value_of(0).unwrap().len(), 2); + assert_eq!( + ctx.value_of(0).unwrap(), + vec![ + Value::String("bar".to_string()), + Value::String("foo".to_string()) + ] + .as_slice() + ); + } +} diff --git a/src/ffi.rs b/src/ffi.rs index 2a4aa8c2..e4dc4e7f 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -319,21 +319,26 @@ pub unsafe extern "C" fn router_get_fields( router: &Router, fields: *mut *const u8, fields_len: *mut usize, + indexes: *mut usize, ) -> usize { if !fields.is_null() { assert!(!fields_len.is_null()); - assert!(*fields_len >= router.fields.len()); + assert!(*fields_len >= router.fields.map.len()); let fields = from_raw_parts_mut(fields, *fields_len); + let indexes = from_raw_parts_mut(indexes, *fields_len); let fields_len = from_raw_parts_mut(fields_len, *fields_len); - for 
(i, k) in router.fields.keys().enumerate() { - fields[i] = k.as_bytes().as_ptr(); - fields_len[i] = k.len() + let mut i = 0; + for (field, pos) in router.fields.map.iter() { + fields[i] = field.as_bytes().as_ptr(); + fields_len[i] = field.len(); + indexes[i] = *pos; + i += 1; } } - router.fields.len() + router.fields.map.len() } /// Allocate a new context object associated with the schema. @@ -348,8 +353,8 @@ pub unsafe extern "C" fn router_get_fields( /// /// - `schema` must be a valid pointer returned by [`schema_new`]. #[no_mangle] -pub unsafe extern "C" fn context_new(schema: &Schema) -> *mut Context { - Box::into_raw(Box::new(Context::new(schema))) +pub unsafe extern "C" fn context_new<'a>(router: &'a Router<'a>) -> *mut Context<'a> { + Box::into_raw(Box::new(Context::new(router))) } /// Deallocate the context object. @@ -414,9 +419,31 @@ pub unsafe extern "C" fn context_add_value( errbuf: *mut u8, errbuf_len: *mut usize, ) -> bool { - let field = ffi::CStr::from_ptr(field as *const c_char) - .to_str() - .unwrap(); + let field = ffi::CStr::from_ptr(field).to_str().unwrap(); + let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); + + let value: Result = value.try_into(); + if let Err(e) = value { + errbuf[..e.len()].copy_from_slice(e.as_bytes()); + unsafe { + *errbuf_len = e.len(); + } + return false; + } + + context.add_value(field, value.unwrap()); + + true +} + +#[no_mangle] +pub unsafe extern "C" fn context_add_value_by_index( + context: &mut Context, + index: usize, + value: &CValue, + errbuf: *mut u8, + errbuf_len: *mut usize, +) -> bool { let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); let value: Result = value.try_into(); @@ -427,8 +454,7 @@ pub unsafe extern "C" fn context_add_value( return false; } - context.add_value(field, value.unwrap()); - + context.add_value_by_index(index, value.unwrap()); true } diff --git a/src/interpreter.rs b/src/interpreter.rs index 2031fe87..ce574f41 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ 
-7,7 +7,7 @@ pub trait Execute { impl Execute for Predicate { fn execute(&self, ctx: &mut Context, m: &mut Match) -> bool { - let lhs_values = match ctx.value_of(&self.lhs.var_name) { + let lhs_values = match ctx.value_of(self.lhs.index) { None => return false, Some(v) => v, }; @@ -261,18 +261,28 @@ impl Execute for Predicate { #[test] fn test_predicate() { use crate::ast; + use crate::router::Router; use crate::schema; + use uuid::Uuid; let mut mat = Match::new(); let mut schema = schema::Schema::default(); schema.add_field("my_key", ast::Type::String); - let mut ctx = Context::new(&schema); + let mut r = Router::new(&schema); + // expression here is not practical, just used to setup context + assert!(r + .add_matcher(1, Uuid::new_v4(), r#"my_key=="whatever""#) + .is_ok()); + + let mut ctx = Context::new(&r); + let field_index: usize = 0; // check when value list is empty // check if all values match starts_with foo -- should be false let p = Predicate { lhs: ast::Lhs { var_name: "my_key".to_string(), + index: field_index, transformations: vec![], }, rhs: Value::String("foo".to_string()), @@ -285,6 +295,7 @@ fn test_predicate() { let p = Predicate { lhs: ast::Lhs { var_name: "my_key".to_string(), + index: field_index, transformations: vec![], }, rhs: Value::String("foo".to_string()), @@ -302,13 +313,14 @@ fn test_predicate() { ]; for v in lhs_values { - ctx.add_value("my_key", v); + ctx.add_value_by_index(field_index, v); } // check if all values match starts_with foo -- should be true let p = Predicate { lhs: ast::Lhs { var_name: "my_key".to_string(), + index: field_index, transformations: vec![], }, rhs: Value::String("foo".to_string()), @@ -321,6 +333,7 @@ fn test_predicate() { let p = Predicate { lhs: ast::Lhs { var_name: "my_key".to_string(), + index: field_index, transformations: vec![], }, rhs: Value::String("foo".to_string()), @@ -333,6 +346,7 @@ fn test_predicate() { let p = Predicate { lhs: ast::Lhs { var_name: "my_key".to_string(), + index: field_index, 
transformations: vec![ast::LhsTransformations::Any], }, rhs: Value::String("foo".to_string()), @@ -345,6 +359,7 @@ fn test_predicate() { let p = Predicate { lhs: ast::Lhs { var_name: "my_key".to_string(), + index: field_index, transformations: vec![ast::LhsTransformations::Any], }, rhs: Value::String("foo".to_string()), @@ -357,6 +372,7 @@ fn test_predicate() { let p = Predicate { lhs: ast::Lhs { var_name: "my_key".to_string(), + index: field_index, transformations: vec![ast::LhsTransformations::Any], }, rhs: Value::String("nar".to_string()), @@ -369,6 +385,7 @@ fn test_predicate() { let p = Predicate { lhs: ast::Lhs { var_name: "my_key".to_string(), + index: field_index, transformations: vec![ast::LhsTransformations::Any], }, rhs: Value::String("".to_string()), @@ -381,6 +398,7 @@ fn test_predicate() { let p = Predicate { lhs: ast::Lhs { var_name: "my_key".to_string(), + index: field_index, transformations: vec![ast::LhsTransformations::Any], }, rhs: Value::String("".to_string()), @@ -393,6 +411,7 @@ fn test_predicate() { let p = Predicate { lhs: ast::Lhs { var_name: "my_key".to_string(), + index: field_index, transformations: vec![ast::LhsTransformations::Any], }, rhs: Value::String("ob".to_string()), @@ -405,6 +424,7 @@ fn test_predicate() { let p = Predicate { lhs: ast::Lhs { var_name: "my_key".to_string(), + index: field_index, transformations: vec![ast::LhsTransformations::Any], }, rhs: Value::String("ok".to_string()), diff --git a/src/parser.rs b/src/parser.rs index 6d59d860..0e3a6986 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -88,6 +88,7 @@ fn parse_lhs(pair: Pair) -> ParseResult { let var = parse_ident(pair)?; Lhs { var_name: var, + index: Default::default(), transformations: Vec::new(), } } diff --git a/src/router.rs b/src/router.rs index 2c94aaa2..13445847 100644 --- a/src/router.rs +++ b/src/router.rs @@ -10,10 +10,16 @@ use uuid::Uuid; #[derive(PartialEq, Eq, PartialOrd, Ord)] struct MatcherKey(usize, Uuid); +pub struct Fields { + pub list: 
Vec>, // fileds list of tuple(name, count) + pub slots: Vec, // slots in list to be reused + pub map: HashMap, // 'name' to 'list index' maping +} + pub struct Router<'a> { schema: &'a Schema, matchers: BTreeMap, - pub fields: HashMap, + pub fields: Fields, } impl<'a> Router<'a> { @@ -21,7 +27,11 @@ impl<'a> Router<'a> { Self { schema, matchers: BTreeMap::new(), - fields: HashMap::new(), + fields: Fields { + list: Vec::new(), + slots: Vec::new(), + map: HashMap::new(), + }, } } @@ -31,10 +41,10 @@ impl<'a> Router<'a> { if self.matchers.contains_key(&key) { return Err("UUID already exists".to_string()); } - + // lhs's index maybe changed in `ast.add_to_counter` let ast = parse(atc).map_err(|e| e.to_string())?; ast.validate(self.schema)?; - let cir = ast.translate(); + let mut cir = ast.translate(); cir.add_to_counter(&mut self.fields); assert!(self.matchers.insert(key, cir).is_none()); @@ -44,8 +54,8 @@ impl<'a> Router<'a> { pub fn remove_matcher(&mut self, priority: usize, uuid: Uuid) -> bool { let key = MatcherKey(priority, uuid); - if let Some(cir) = self.matchers.remove(&key) { - cir.remove_from_counter(&mut self.fields); + if let Some(mut ast) = self.matchers.remove(&key) { + ast.remove_from_counter(&mut self.fields); return true; } @@ -65,4 +75,205 @@ impl<'a> Router<'a> { false } + + pub fn schema(&self) -> &Schema { + &self.schema + } +} + +#[cfg(test)] +mod tests { + use std::{ + cmp::max, + collections::HashMap, + net::{IpAddr, Ipv4Addr}, + }; + + use uuid::Uuid; + + use crate::{ + ast::{Type, Value}, + cir::{get_predicates, CirProgram}, + context::Context, + router::Router, + schema::Schema, + }; + + type ContextValues<'a> = HashMap<&'a str, Value>; + + fn setup_matcher(r: &mut Router, priority: usize, expression: &str) -> (Uuid, usize) { + let id = Uuid::new_v4(); + r.add_matcher(priority, id, expression) + .ok() + .expect("failed to addd matcher"); + (id, priority) + } + + fn init_context<'a>(r: &'a Router, ctx_values: &'a ContextValues<'a>) -> 
Context<'a> { + let mut ctx = Context::new(r); + for (i, v) in r.fields.list.iter().enumerate() { + if v.is_none() { + continue; + } + let key = &v.as_ref().unwrap().0; + if ctx_values.contains_key(key.as_str()) { + ctx.add_value_by_index(i, ctx_values.get(key.as_str()).unwrap().clone()); + } + } + ctx + } + + fn is_index_match(cir: &CirProgram, rt: &Router) -> bool { + let predicates = get_predicates(cir); + for p in predicates { + if rt.fields.list[p.lhs.index].as_ref().unwrap().0 == p.lhs.var_name + && *rt.fields.map.get(&p.lhs.var_name).unwrap() == p.lhs.index + { + continue; + } + return false; + } + true + } + + fn validate_index(r: &Router) -> bool { + for (_, e) in r.matchers.iter() { + if !is_index_match(e, r) { + return false; + } + } + true + } + + #[test] + fn test_router_execution() { + // init schema + let mut s = Schema::default(); + s.add_field("http.host", Type::String); + s.add_field("net.dst.port", Type::Int); + s.add_field("net.src.ip", Type::IpAddr); + + // init router + let mut r = Router::new(&s); + assert!(r.fields.list.len() == 0); + assert!(validate_index(&r)); + + // add matchers + let (id_0, pri_0) = setup_matcher(&mut r, 99, r#"http.host == "example.com""#); + let (id_1, pri_1) = + setup_matcher(&mut r, 98, r#"net.dst.port == 8443 || net.dst.port == 443"#); + let (id_2, pri_2) = setup_matcher(&mut r, 97, r#"net.src.ip == 192.168.1.1"#); + assert!(r.fields.list.len() == 3); + assert!(r.fields.slots.len() == 0); + assert!(validate_index(&r)); + + // mock context values + let mut ctx_values = HashMap::from([ + ("http.host", Value::String("example.com".to_string())), + ( + "net.src.ip", + Value::IpAddr(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2))), + ), + ]); + let mut ctx = init_context(&r, &ctx_values); + + // match the first matcher + let res = r.execute(&mut ctx); + assert!(res); + + // delete matcher, no field match now + r.remove_matcher(pri_0, id_0); + assert!(r.fields.list.len() == 3); + assert!(r.fields.slots.len() == 1); + 
assert!(validate_index(&r)); + ctx = init_context(&r, &ctx_values); + assert!(!r.execute(&mut ctx)); + + // context value change, match again + *ctx_values.get_mut("net.src.ip").unwrap() = + Value::IpAddr(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))); + ctx = init_context(&r, &ctx_values); + assert!(r.execute(&mut ctx)); + + // delete all matchers + r.remove_matcher(pri_1, id_1); + r.remove_matcher(pri_2, id_2); + assert!(r.fields.list.len() == 3); + assert!(r.fields.slots.len() == 3); + assert!(validate_index(&r)); + ctx = init_context(&r, &ctx_values); + assert!(!r.execute(&mut ctx)); + + // add a new matcher + let (_, _) = setup_matcher(&mut r, 96, r#"net.src.ip == 192.168.1.1"#); + assert!(r.fields.list.len() == 3); + assert!(r.fields.slots.len() == 2); + assert!(validate_index(&r)); + ctx = init_context(&r, &ctx_values); + assert!(r.execute(&mut ctx)); + } + + #[test] + fn test_fields_list() { + let mut s = Schema::default(); + s.add_field("http.path.segments.*", Type::String); + let mut r = Router::new(&s); + let i_max = 1000; + let mut ids = vec![]; + for i in 0..i_max { + let id: Uuid = Uuid::new_v4(); + let exp = format!(r#"http.path.segments.{} == "/bar""#, i.to_string()); + let pri = i; + assert!(r.add_matcher(pri, id, exp.as_str()).is_ok()); + assert!(r.fields.list.len() == i + 1); + assert!(r.fields.slots.len() == 0); + assert!(r.fields.map.len() == i + 1); + ids.push((pri, id)); + } + + // delete 100 fields + let mut valid_cnt = i_max; + for (idx, id) in &ids[100..200] { + let pri = idx; + assert!(r.remove_matcher(*pri, *id)); + valid_cnt -= 1; + assert!(r.fields.list.len() == i_max); + assert!(r.fields.slots.len() == i_max - valid_cnt); + assert!(r.fields.map.len() == valid_cnt); + } + + // deleted fields leave None in fields list + for i in 100..200 { + assert!(r.fields.list[i] == None); + } + + // adds 200 fields back + let fields_len = r.fields.list.len(); + let mut slot_cnt = r.fields.slots.len(); + for i in 0..200 { + let id: Uuid = Uuid::new_v4(); 
+ let exp = format!( + r#"http.path.segments.{} == "/bar""#, + (i_max + i).to_string() + ); + let pri = i; + if slot_cnt > 0 { + slot_cnt -= 1; + } + assert!(r.add_matcher(pri, id, exp.as_str()).is_ok()); + assert!(r.fields.list.len() == max(fields_len, r.fields.map.len())); + assert!(r.fields.slots.len() == slot_cnt); + assert!(r.fields.map.len() == r.fields.list.len() - slot_cnt); + } + + // 100 slot deleted before should be reused + for i in 100..200 { + assert!(r.fields.list[i].is_some()); + } + + // 100 slot newly added should be valid + for i in i_max..i_max + 100 { + assert!(r.fields.list[i].is_some()); + } + } } diff --git a/src/semantics.rs b/src/semantics.rs index 9c4dc5b8..8467e7d7 100644 --- a/src/semantics.rs +++ b/src/semantics.rs @@ -1,6 +1,6 @@ use crate::ast::{BinaryOperator, Expression, LogicalExpression, Type, Value}; +use crate::router::Fields; use crate::schema::Schema; -use std::collections::HashMap; type ValidationResult = Result<(), String>; @@ -9,8 +9,78 @@ pub trait Validate { } pub trait FieldCounter { - fn add_to_counter(&self, map: &mut HashMap); - fn remove_from_counter(&self, map: &mut HashMap); + fn add_to_counter(&mut self, fields: &mut Fields); + fn remove_from_counter(&mut self, fields: &mut Fields); +} + +impl FieldCounter for Expression { + fn add_to_counter(&mut self, fields: &mut Fields) { + match self { + Expression::Logical(l) => match l.as_mut() { + LogicalExpression::And(l, r) => { + l.add_to_counter(fields); + r.add_to_counter(fields); + } + LogicalExpression::Or(l, r) => { + l.add_to_counter(fields); + r.add_to_counter(fields); + } + LogicalExpression::Not(r) => { + r.add_to_counter(fields); + } + }, + Expression::Predicate(p) => { + // 1. fields: increment counter for field + // 2. lhs: assign field index to the LHS + // 3. 
map: maintain the fields map: {field_name : field_index} + if let Some(index) = fields.map.get(&p.lhs.var_name) { + fields.list[*index].as_mut().unwrap().1 += 1; + p.lhs.index = *index; + } else { + // reuse slots in queue if possible + let new_idx: usize; + if fields.slots.is_empty() { + fields.list.push(Some((p.lhs.var_name.clone(), 1))); + new_idx = fields.list.len() - 1; + } else { + new_idx = fields.slots.pop().unwrap(); + fields.list[new_idx] = Some((p.lhs.var_name.clone(), 1)); + } + fields.map.insert(p.lhs.var_name.clone(), new_idx); + p.lhs.index = new_idx; + } + } + } + } + + fn remove_from_counter(&mut self, fields: &mut Fields) { + match self { + Expression::Logical(l) => match l.as_mut() { + LogicalExpression::And(l, r) => { + l.remove_from_counter(fields); + r.remove_from_counter(fields); + } + LogicalExpression::Or(l, r) => { + l.remove_from_counter(fields); + r.remove_from_counter(fields); + } + LogicalExpression::Not(r) => { + r.remove_from_counter(fields); + } + }, + Expression::Predicate(p) => { + let index: usize = p.lhs.index; + // decrement counter of field + fields.list[index].as_mut().unwrap().1 -= 1; + // for field removing, reserve the slot for resue and remove it in map + if fields.list[index].as_mut().unwrap().1 == 0 { + fields.list[index] = None; + fields.slots.push(index); + assert!(fields.map.remove(&p.lhs.var_name).is_some()); + } + } + } + } } impl Validate for Expression { diff --git a/t/01-sanity.t b/t/01-sanity.t index 70f007f9..d334863f 100644 --- a/t/01-sanity.t +++ b/t/01-sanity.t @@ -39,7 +39,7 @@ __DATA__ assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "http.path ^= \"/foo\" && tcp.port == 80")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.path", "/foo/bar") c:add_value("tcp.port", 80) @@ -49,6 +49,18 @@ __DATA__ local uuid, prefix = c:get_result("http.path") ngx.say(uuid) ngx.say(prefix) + + -- context set by index + c:reset() + c:add_value_by_index("http.path", "/foo/bar", 0) + 
c:add_value_by_index("tcp.port", 80, 1) + + matched = r:execute(c) + ngx.say(matched) + + uuid, prefix = c:get_result("http.path") + ngx.say(uuid) + ngx.say(prefix) } } --- request @@ -57,6 +69,9 @@ GET /t true a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c /foo +true +a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c +/foo --- no_error_log [error] [warn] @@ -84,7 +99,7 @@ a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150d", "http.path ^= \"/\"")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.path", "/foo/bar") c:add_value("tcp.port", 80) @@ -94,6 +109,17 @@ a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c local uuid, prefix = c:get_result("http.path") ngx.say("uuid = " .. uuid .. " prefix = " .. prefix) + + -- context set by index + c:reset() + c:add_value_by_index("http.path", "/foo/bar", 0) + c:add_value_by_index("tcp.port", 80, 1) + + matched = r:execute(c) + ngx.say(matched) + + uuid, prefix = c:get_result("http.path") + ngx.say("uuid = " .. uuid .. " prefix = " .. 
prefix) } } --- request @@ -101,6 +127,8 @@ GET /t --- response_body true uuid = a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c prefix = /foo +true +uuid = a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c prefix = /foo --- no_error_log [error] [warn] @@ -126,7 +154,7 @@ uuid = a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c prefix = /foo assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "http.path ^= \"/foo\" && tcp.port == 80")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.path", "/foo/bar") c:add_value("tcp.port", 80) @@ -138,8 +166,8 @@ uuid = a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c prefix = /foo ngx.say(prefix) assert(r:remove_matcher("a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c")) - - c = context.new(s) + + c = context.new(r) c:add_value("http.path", "/foo/bar") c:add_value("tcp.port", 80) @@ -214,7 +242,7 @@ nil --> 1:11 assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "http.path ^= \"/foo\" && tcp.port == 80")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.path", "/foo/bar") c:add_value("tcp.port", 80) diff --git a/t/02-bugs.t b/t/02-bugs.t index c10aae45..e74bf2e7 100644 --- a/t/02-bugs.t +++ b/t/02-bugs.t @@ -28,10 +28,15 @@ __DATA__ content_by_lua_block { local schema = require("resty.router.schema") local context = require("resty.router.context") + local router = require("resty.router.router") local s = schema.new() s:add_field("http.path", "String") + s:add_field("tcp.port", "Int") + local r = router.new(s) + assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", + "http.path ^= \"/foo\" && tcp.port == 80")) local BAD_UTF8 = { "\x80", @@ -39,7 +44,7 @@ __DATA__ "\xfc\x80\x80\x80\x80\xaf", } - local c = context.new(s) + local c = context.new(r) for _, v in ipairs(BAD_UTF8) do local ok, err = c:add_value("http.path", v) ngx.say(err) @@ -66,12 +71,15 @@ invalid utf-8 sequence of 1 bytes from index 0 content_by_lua_block { local schema = require("resty.router.schema") local context = 
require("resty.router.context") + local router = require("resty.router.router") local s = schema.new() s:add_field("http.path", "String") + s:add_field("tcp.port", "Int") + local r = router.new(s) - local c = context.new(s) + local c = context.new(r) assert(c:add_value("http.path", "\x00")) ngx.say("ok") } @@ -146,7 +154,7 @@ ok assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "http.body =^ \"world\"")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.body", "hello\x00world") local matched = r:execute(c) diff --git a/t/03-contains.t b/t/03-contains.t index 88e00be3..c419551b 100644 --- a/t/03-contains.t +++ b/t/03-contains.t @@ -39,7 +39,7 @@ __DATA__ assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "http.path contains \"keyword\" && tcp.port == 80")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.path", "/foo/keyword/bar") c:add_value("tcp.port", 80) @@ -83,7 +83,7 @@ nil assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "http.path contains \"keyword\" && tcp.port == 80")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.path", "/foo/bar") c:add_value("tcp.port", 80) diff --git a/t/04-rawstr.t b/t/04-rawstr.t index 80cc1d31..eada3558 100644 --- a/t/04-rawstr.t +++ b/t/04-rawstr.t @@ -39,7 +39,7 @@ __DATA__ assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "http.path ^= r#\"/foo\"# && tcp.port == 80")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.path", "/foo/bar") c:add_value("tcp.port", 80) @@ -82,7 +82,7 @@ a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "http.path ^= r#\"/foo\"\'\"# && tcp.port == 80")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.path", "/foo\"\'/bar") c:add_value("tcp.port", 80) @@ -126,7 +126,7 @@ a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "http.path ~ 
r#\"^/\\d+/test$\"# && tcp.port == 80")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.path", "/123/test") c:add_value("tcp.port", 80) @@ -169,7 +169,7 @@ a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "http.path ~ r#\"^/\\D+/test$\"# && tcp.port == 80")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.path", "/123/test") c:add_value("tcp.port", 80) diff --git a/t/05-equals.t b/t/05-equals.t index c41de0a8..23878a31 100644 --- a/t/05-equals.t +++ b/t/05-equals.t @@ -38,7 +38,7 @@ __DATA__ assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "http.headers.foo == \"bar\"")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.headers.foo", "bar") c:add_value("http.headers.foo", "bar") c:add_value("http.headers.foo", "bar") @@ -82,7 +82,7 @@ bar assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "http.headers.foo == \"bar\"")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.headers.foo", "bar") c:add_value("http.headers.foo", "bar") c:add_value("http.headers.foo", "barX") @@ -124,7 +124,7 @@ nil assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "http.headers.foo == \"bar\"")) - local c = context.new(s) + local c = context.new(r) local matched = r:execute(c) ngx.say(matched) diff --git a/t/06-validate.t b/t/06-validate.t index a4a9b2a7..ef63990a 100644 --- a/t/06-validate.t +++ b/t/06-validate.t @@ -37,8 +37,11 @@ __DATA__ ngx.say(type(r)) ngx.say(err) - ngx.say(#r) - ngx.say(r[1]) + + for k, v in pairs(r) do + ngx.say(k) + ngx.say(tonumber(v)) + end } } --- request @@ -46,8 +49,8 @@ GET /t --- response_body table nil -1 http.headers.foo +0 --- no_error_log [error] [warn] diff --git a/t/07-in_notin.t b/t/07-in_notin.t index d925072c..b4d3ce49 100644 --- a/t/07-in_notin.t +++ b/t/07-in_notin.t @@ -76,13 +76,13 @@ nilIn/NotIn operators only supports IP in CIDR assert(r:add_matcher(0, 
"a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", "l3.ip in 192.168.12.0/24")) - local c = context.new(s) + local c = context.new(r) c:add_value("l3.ip", "192.168.12.1") local matched = r:execute(c) ngx.say(matched) - c = context.new(s) + c = context.new(r) c:add_value("l3.ip", "192.168.1.1") local matched = r:execute(c) diff --git a/t/08-equals.t b/t/08-equals.t index 3f8090a6..e6039848 100644 --- a/t/08-equals.t +++ b/t/08-equals.t @@ -40,14 +40,14 @@ __DATA__ assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-8aa5583d150c", "net.port != 8000")) - local c = context.new(s) + local c = context.new(r) c:add_value("net.port", 8000) local matched = r:execute(c) ngx.say(matched) ngx.say(c:get_result()) - c = context.new(s) + c = context.new(r) c:add_value("net.port", 8001) matched = r:execute(c) @@ -88,14 +88,14 @@ a921a9aa-ec0e-4cf3-a6cc-8aa5583d150cnilnil assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-8aa5583d150c", "http.path != \"/foo\"")) - local c = context.new(s) + local c = context.new(r) c:add_value("http.path", "/foo") local matched = r:execute(c) ngx.say(matched) ngx.say(c:get_result()) - c = context.new(s) + c = context.new(r) c:add_value("http.path", "/foo1") matched = r:execute(c) @@ -136,14 +136,14 @@ a921a9aa-ec0e-4cf3-a6cc-8aa5583d150cnilnil assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-8aa5583d150c", "net.ip != 192.168.1.1")) - local c = context.new(s) + local c = context.new(r) c:add_value("net.ip", "192.168.1.1") local matched = r:execute(c) ngx.say(matched) ngx.say(c:get_result()) - c = context.new(s) + c = context.new(r) c:add_value("net.ip", "192.168.1.2") matched = r:execute(c) diff --git a/t/09-not.t b/t/09-not.t index 518ebbdd..142ae9c2 100644 --- a/t/09-not.t +++ b/t/09-not.t @@ -38,7 +38,7 @@ __DATA__ assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", [[!(http.path ^= "/abc")]])) - local c = context.new(s) + local c = context.new(r) c:add_value("http.path", "/abc/d") local matched = r:execute(c)