Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions src/ast.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::schema::Schema;
use cidr::IpCidr;
use pest::error::LineColLocation;
use regex::Regex;
use std::net::IpAddr;

Expand All @@ -13,12 +14,46 @@ pub enum Expression {
Predicate(Predicate),
}

#[derive(Debug)]
pub struct Span {
pub start_line: usize,
pub end_line: usize,
pub start_col: usize,
pub end_col: usize,
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LocationedExpression {
pub expression: Expression,
pub span: LineColLocation,
}

impl LocationedExpression {
pub fn new(expression: Expression, span: LineColLocation) -> Self {
LocationedExpression { expression, span }
}
}

impl From<LocationedExpression> for Expression {
fn from(loc_expr: LocationedExpression) -> Self {
loc_expr.expression
}
}

impl From<Expression> for LocationedExpression {
fn from(expr: Expression) -> Self {
// unknown location
LocationedExpression::new(expr, LineColLocation::Pos((0, 0)))
}
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub enum LogicalExpression {
And(Expression, Expression),
Or(Expression, Expression),
Not(Expression),
And(LocationedExpression, LocationedExpression),
Or(LocationedExpression, LocationedExpression),
Not(LocationedExpression),
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down Expand Up @@ -152,6 +187,12 @@ mod tests {
}
}

impl fmt::Display for LocationedExpression {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.expression)
}
}

impl fmt::Display for LogicalExpression {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
Expand Down
34 changes: 19 additions & 15 deletions src/cir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,43 +80,43 @@ fn cir_translate_helper(exp: &Expression, cir: &mut CirProgram) -> usize {
match exp {
Expression::Logical(logic_exp) => match logic_exp.as_ref() {
LogicalExpression::And(l, r) => {
let left = match l {
let left = match &l.expression {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(l, cir))
CirOperand::Index(cir_translate_helper(&l.expression, cir))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};

let right = match r {
let right = match &r.expression {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(r, cir))
CirOperand::Index(cir_translate_helper(&r.expression, cir))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};
let and_ins = AndIns { left, right };
cir.instructions.push(CirInstruction::AndIns(and_ins));
}
LogicalExpression::Or(l, r) => {
let left = match l {
let left = match &l.expression {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(l, cir))
CirOperand::Index(cir_translate_helper(&l.expression, cir))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};

let right = match r {
let right = match &r.expression {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(r, cir))
CirOperand::Index(cir_translate_helper(&r.expression, cir))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};
let or_ins = OrIns { left, right };
cir.instructions.push(CirInstruction::OrIns(or_ins));
}
LogicalExpression::Not(r) => {
let right: CirOperand = match r {
let right: CirOperand = match &r.expression {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(r, cir))
CirOperand::Index(cir_translate_helper(&r.expression, cir))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};
Expand Down Expand Up @@ -267,9 +267,13 @@ mod tests {
use crate::ast::{Expression, LogicalExpression};
match self {
Expression::Logical(l) => match l.as_ref() {
LogicalExpression::And(l, r) => l.execute(ctx, m) && r.execute(ctx, m),
LogicalExpression::Or(l, r) => l.execute(ctx, m) || r.execute(ctx, m),
LogicalExpression::Not(r) => !r.execute(ctx, m),
LogicalExpression::And(l, r) => {
l.expression.execute(ctx, m) && r.expression.execute(ctx, m)
}
LogicalExpression::Or(l, r) => {
l.expression.execute(ctx, m) || r.expression.execute(ctx, m)
}
LogicalExpression::Not(r) => !r.expression.execute(ctx, m),
},
Expression::Predicate(p) => p.execute(ctx, m),
}
Expand Down Expand Up @@ -302,9 +306,9 @@ mod tests {
.map_err(|e| e.to_string())
.unwrap();
let mut mat = Match::new();
let ast_result = ast.execute(&mut context, &mut mat);
let ast_result = ast.expression.execute(&mut context, &mut mat);

let cir_result = ast.translate().execute(&mut context, &mut mat);
let cir_result = ast.expression.translate().execute(&mut context, &mut mat);
assert_eq!(ast_result, cir_result);
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/ffi/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ impl<'a> Iterator for PredicateIterator<'a> {
match expr {
Expression::Logical(l) => match l.as_ref() {
LogicalExpression::And(l, r) | LogicalExpression::Or(l, r) => {
self.stack.push(l);
self.stack.push(r);
self.stack.push(&l.expression);
self.stack.push(&r.expression);
}
LogicalExpression::Not(r) => {
self.stack.push(r);
self.stack.push(&r.expression);
}
},
Expression::Predicate(p) => return Some(p),
Expand Down Expand Up @@ -189,7 +189,7 @@ pub unsafe extern "C" fn expression_validate(
let mut fields_buf_ptr = fields_buf;
*fields_total = 0;

for pred in ast.iter_predicates() {
for pred in ast.expression.iter_predicates() {
ops |= BinaryOperatorFlags::from(&pred.op);

let field = pred.lhs.var_name.as_str();
Expand Down
52 changes: 38 additions & 14 deletions src/parser.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
extern crate pest;

use crate::ast::{
BinaryOperator, Expression, Lhs, LhsTransformations, LogicalExpression, Predicate, Value,
BinaryOperator, Expression, Lhs, LhsTransformations, LocationedExpression, LogicalExpression,
Predicate, Value,
};
use cidr::{IpCidr, Ipv4Cidr, Ipv6Cidr};
use pest::error::Error as ParseError;
use pest::error::ErrorVariant;
use pest::error::{Error as ParseError, LineColLocation};
use pest::iterators::Pair;
use pest::pratt_parser::Assoc as AssocNew;
use pest::pratt_parser::{Op, PrattParser};
Expand Down Expand Up @@ -61,7 +62,7 @@ impl ATCParser {
}
// matcher = { SOI ~ expression ~ EOI }
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
fn parse_matcher(&mut self, source: &str) -> ParseResult<Expression> {
fn parse_matcher(&mut self, source: &str) -> ParseResult<LocationedExpression> {
let pairs = ATCParser::parse(Rule::matcher, source)?;
let expr_pair = pairs.peek().unwrap().into_inner().peek().unwrap();
let rule = expr_pair.as_rule();
Expand All @@ -72,6 +73,10 @@ impl ATCParser {
}
}

fn non_ref_span_from_pair(pair: &Pair<Rule>) -> LineColLocation {
pair.as_span().into()
}

#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
fn parse_ident(pair: Pair<Rule>) -> ParseResult<String> {
Ok(pair.as_str().into())
Expand Down Expand Up @@ -289,53 +294,72 @@ fn parse_binary_operator(pair: Pair<Rule>) -> BinaryOperator {
fn parse_parenthesised_expression(
pair: Pair<Rule>,
pratt: &PrattParser<Rule>,
) -> ParseResult<Expression> {
) -> ParseResult<LocationedExpression> {
let mut pairs = pair.into_inner();
let pair = pairs.next().unwrap();
let rule = pair.as_rule();
match rule {
Rule::expression => parse_expression(pair, pratt),
Rule::not_op => Ok(Expression::Logical(Box::new(LogicalExpression::Not(
parse_expression(pairs.next().unwrap(), pratt)?,
)))),
Rule::not_op => Ok(LocationedExpression::new(
Expression::Logical(Box::new(LogicalExpression::Not(parse_expression(
pairs.next().unwrap(),
pratt,
)?))),
non_ref_span_from_pair(&pair),
)),
_ => unreachable!(),
}
}

// term = { predicate | parenthesised_expression }
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
fn parse_term(pair: Pair<Rule>, pratt: &PrattParser<Rule>) -> ParseResult<Expression> {
fn parse_term(pair: Pair<Rule>, pratt: &PrattParser<Rule>) -> ParseResult<LocationedExpression> {
let span = non_ref_span_from_pair(&pair);
let pairs = pair.into_inner();
let inner_rule = pairs.peek().unwrap();
let rule = inner_rule.as_rule();
match rule {
Rule::predicate => Ok(Expression::Predicate(parse_predicate(inner_rule)?)),
Rule::predicate => Ok(LocationedExpression::new(
Expression::Predicate(parse_predicate(inner_rule)?),
span,
)),
Rule::parenthesised_expression => parse_parenthesised_expression(inner_rule, pratt),
_ => unreachable!(),
}
}

// expression = { term ~ ( logical_operator ~ term )* }
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
fn parse_expression(pair: Pair<Rule>, pratt: &PrattParser<Rule>) -> ParseResult<Expression> {
fn parse_expression(
pair: Pair<Rule>,
pratt: &PrattParser<Rule>,
) -> ParseResult<LocationedExpression> {
let span = non_ref_span_from_pair(&pair);
let pairs = pair.into_inner();
pratt
.map_primary(|operand| match operand.as_rule() {
Rule::term => parse_term(operand, pratt),
_ => unreachable!(),
})
.map_infix(|lhs, op, rhs| {
.map_infix(move |lhs, op, rhs| {
let span = span.clone();
Ok(match op.as_rule() {
Rule::and_op => Expression::Logical(Box::new(LogicalExpression::And(lhs?, rhs?))),
Rule::or_op => Expression::Logical(Box::new(LogicalExpression::Or(lhs?, rhs?))),
Rule::and_op => LocationedExpression::new(
Expression::Logical(Box::new(LogicalExpression::And(lhs?.into(), rhs?.into()))),
span,
),
Rule::or_op => LocationedExpression::new(
Expression::Logical(Box::new(LogicalExpression::Or(lhs?.into(), rhs?.into()))),
span,
),
_ => unreachable!(),
})
})
.parse(pairs)
}

#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path
pub fn parse(source: &str) -> ParseResult<Expression> {
pub fn parse(source: &str) -> ParseResult<LocationedExpression> {
ATCParser::new().parse_matcher(source)
}

Expand Down
6 changes: 4 additions & 2 deletions src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use crate::interpreter::Execute;
use crate::parser::parse;
use crate::schema::Schema;
use crate::semantics::{FieldCounter, Validate};
use pest::error::LineColLocation;
use std::collections::{BTreeMap, HashMap};
use std::f32::consts::E;
use uuid::Uuid;

#[derive(PartialEq, Eq, PartialOrd, Ord)]
Expand Down Expand Up @@ -33,8 +35,8 @@ impl<'a> Router<'a> {
}

let ast = parse(atc).map_err(|e| e.to_string())?;
ast.validate(self.schema)?;
let cir = ast.translate();
ast.validate(self.schema).map_err(|e| e.to_string())?;
let cir = ast.expression.translate();
cir.add_to_counter(&mut self.fields);
assert!(self.matchers.insert(key, cir).is_none());

Expand Down
Loading