Skip to content

Commit b2fe7b7

Browse files
committed
sql: clean up types module
1 parent 83cd189 commit b2fe7b7

File tree

6 files changed

+254
-228
lines changed

6 files changed

+254
-228
lines changed

src/sql/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ mod tests {
305305
if tags.remove("cnf") {
306306
let cnf = expr.clone().into_cnf();
307307
assert_eq!(value, cnf.evaluate(None)?, "CNF result differs");
308-
write!(output, " ← {}", cnf.format_constant())?;
308+
write!(output, " ← {cnf}")?;
309309
}
310310

311311
// If requested, debug-dump the parsed expression.

src/sql/testscripts/expressions/func_sqrt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
!> sqrt(-1)
1212
> sqrt(0)
1313
---
14-
Error: invalid input: can't take negative square root
14+
Error: invalid input: can't take square root of -1
1515
0.0
1616

1717
# Floats work.

src/sql/testscripts/expressions/op_math_factorial

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ Error: invalid input: can't take factorial of 3.0
1616
Error: invalid input: can't take factorial of TRUE
1717
Error: invalid input: can't take factorial of 'a'
1818

19+
# 0 factorial is 1, but negative factorial errors.
20+
> -0!
21+
!> -1!
22+
---
23+
1
24+
Error: invalid input: can't take factorial of -1
25+
1926
# NULL yields null, infinity and NaN error.
2027
> NULL!
2128
!> INFINITY!

src/sql/types/expression.rs

Lines changed: 94 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::fmt::Display;
2+
13
use regex::Regex;
24
use serde::{Deserialize, Serialize};
35

@@ -7,8 +9,8 @@ use crate::error::Result;
79
use crate::sql::planner::Node;
810

911
/// An expression, made up of nested operations and values. Values are either
10-
/// constants or dynamic column references. Evaluates to a final value during
11-
/// query execution, using row values for column references.
12+
/// constants, or column references which are looked up in rows. Evaluated to a
13+
/// final value during query execution.
1214
///
1315
/// Since this is a recursive data structure, we have to box each child
1416
/// expression, which incurs a heap allocation per expression node. There are
@@ -17,57 +19,65 @@ use crate::sql::planner::Node;
1719
pub enum Expression {
1820
/// A constant value.
1921
Constant(Value),
20-
/// A column reference. Used as row index when evaluating expressions.
22+
/// A column reference. Looks up the value in a row during evaluation.
2123
Column(usize),
2224

23-
/// Logical AND of two booleans: a AND b.
25+
/// a AND b: logical AND of two booleans.
2426
And(Box<Expression>, Box<Expression>),
25-
/// Logical OR of two booleans: a OR b.
27+
/// a OR b: logical OR of two booleans.
2628
Or(Box<Expression>, Box<Expression>),
27-
/// Logical NOT of a boolean: NOT a.
29+
/// NOT a: logical NOT of a boolean.
2830
Not(Box<Expression>),
2931

30-
/// Equality comparison of two values: a = b.
32+
/// a = b: equality comparison of two values.
3133
Equal(Box<Expression>, Box<Expression>),
3234
/// Greater than comparison of two values: a > b.
3335
GreaterThan(Box<Expression>, Box<Expression>),
34-
/// Less than comparison of two values: a < b.
36+
/// a < b: less than comparison of two values.
3537
LessThan(Box<Expression>, Box<Expression>),
36-
/// Checks for the given value: IS NULL or IS NAN.
38+
/// a IS NULL or a IS NAN: checks for the given value.
3739
Is(Box<Expression>, Value),
3840

39-
/// Adds two numbers: a + b.
41+
/// a + b: adds two numbers.
4042
Add(Box<Expression>, Box<Expression>),
41-
/// Divides two numbers: a / b.
43+
/// a / b: divides two numbers.
4244
Divide(Box<Expression>, Box<Expression>),
43-
/// Exponentiates two numbers, i.e. a ^ b.
45+
/// a ^b: exponentiates two numbers.
4446
Exponentiate(Box<Expression>, Box<Expression>),
45-
/// Takes the factorial of a number: 4! = 4*3*2*1.
47+
/// a!: takes the factorial of a number (4! = 4*3*2*1).
4648
Factorial(Box<Expression>),
47-
/// The identify function, which simply returns the same number: +a.
49+
/// +a: the identify function, which simply returns the same number.
4850
Identity(Box<Expression>),
49-
/// Multiplies two numbers: a * b.
51+
/// a * b: multiplies two numbers.
5052
Multiply(Box<Expression>, Box<Expression>),
51-
/// Negates the given number: -a.
53+
/// -a: negates the given number.
5254
Negate(Box<Expression>),
53-
/// The remainder after dividing two numbers: a % b.
55+
/// a % b: the remainder after dividing two numbers.
5456
Remainder(Box<Expression>, Box<Expression>),
55-
/// Takes the square root of a number: √a.
57+
/// √a: takes the square root of a number.
5658
SquareRoot(Box<Expression>),
57-
/// Subtracts two numbers: a - b.
59+
/// a - b: subtracts two numbers.
5860
Subtract(Box<Expression>, Box<Expression>),
5961

60-
// Checks if a string matches a pattern: a LIKE b.
62+
// a LIKE b: checks if a string matches a pattern.
6163
Like(Box<Expression>, Box<Expression>),
6264
}
6365

66+
// NB: display can't look up column labels, and will print numeric column
67+
// indexes instead. Use Expression::format() to print with labels.
68+
impl Display for Expression {
69+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70+
write!(f, "{}", self.format(&Node::Nothing { columns: Vec::new() }))
71+
}
72+
}
73+
6474
impl Expression {
6575
/// Formats the expression, using the given plan node to look up labels for
66-
/// numeric column references.
76+
/// column references.
6777
pub fn format(&self, node: &Node) -> String {
6878
use Expression::*;
6979

70-
// Precedence levels, for grouping. Matches the parser precedence.
80+
// Precedence levels, for () grouping. Matches the parser precedence.
7181
fn precedence(expr: &Expression) -> u8 {
7282
match expr {
7383
Column(_) | Constant(_) | SquareRoot(_) => 11,
@@ -85,7 +95,7 @@ impl Expression {
8595
}
8696

8797
// Helper to format a boxed expression, grouping it with () if needed.
88-
let format = |expr: &Expression| {
98+
let fmt = |expr: &Expression| {
8999
let mut string = expr.format(node);
90100
if precedence(expr) < precedence(self) {
91101
string = format!("({string})");
@@ -100,51 +110,44 @@ impl Expression {
100110
label => format!("{label}"),
101111
},
102112

103-
And(lhs, rhs) => format!("{} AND {}", format(lhs), format(rhs)),
104-
Or(lhs, rhs) => format!("{} OR {}", format(lhs), format(rhs)),
105-
Not(expr) => format!("NOT {}", format(expr)),
113+
And(lhs, rhs) => format!("{} AND {}", fmt(lhs), fmt(rhs)),
114+
Or(lhs, rhs) => format!("{} OR {}", fmt(lhs), fmt(rhs)),
115+
Not(expr) => format!("NOT {}", fmt(expr)),
106116

107-
Equal(lhs, rhs) => format!("{} = {}", format(lhs), format(rhs)),
108-
GreaterThan(lhs, rhs) => format!("{} > {}", format(lhs), format(rhs)),
109-
LessThan(lhs, rhs) => format!("{} < {}", format(lhs), format(rhs)),
110-
Is(expr, Value::Null) => format!("{} IS NULL", format(expr)),
111-
Is(expr, Value::Float(f)) if f.is_nan() => format!("{} IS NAN", format(expr)),
117+
Equal(lhs, rhs) => format!("{} = {}", fmt(lhs), fmt(rhs)),
118+
GreaterThan(lhs, rhs) => format!("{} > {}", fmt(lhs), fmt(rhs)),
119+
LessThan(lhs, rhs) => format!("{} < {}", fmt(lhs), fmt(rhs)),
120+
Is(expr, Value::Null) => format!("{} IS NULL", fmt(expr)),
121+
Is(expr, Value::Float(f)) if f.is_nan() => format!("{} IS NAN", fmt(expr)),
112122
Is(_, v) => panic!("unexpected IS value {v}"),
113123

114-
Add(lhs, rhs) => format!("{} + {}", format(lhs), format(rhs)),
115-
Divide(lhs, rhs) => format!("{} / {}", format(lhs), format(rhs)),
116-
Exponentiate(lhs, rhs) => format!("{} ^ {}", format(lhs), format(rhs)),
117-
Factorial(expr) => format!("{}!", format(expr)),
118-
Identity(expr) => format(expr),
119-
Multiply(lhs, rhs) => format!("{} * {}", format(lhs), format(rhs)),
120-
Negate(expr) => format!("-{}", format(expr)),
121-
Remainder(lhs, rhs) => format!("{} % {}", format(lhs), format(rhs)),
122-
SquareRoot(expr) => format!("sqrt({})", format(expr)),
123-
Subtract(lhs, rhs) => format!("{} - {}", format(lhs), format(rhs)),
124-
125-
Like(lhs, rhs) => format!("{} LIKE {}", format(lhs), format(rhs)),
124+
Add(lhs, rhs) => format!("{} + {}", fmt(lhs), fmt(rhs)),
125+
Divide(lhs, rhs) => format!("{} / {}", fmt(lhs), fmt(rhs)),
126+
Exponentiate(lhs, rhs) => format!("{} ^ {}", fmt(lhs), fmt(rhs)),
127+
Factorial(expr) => format!("{}!", fmt(expr)),
128+
Identity(expr) => fmt(expr),
129+
Multiply(lhs, rhs) => format!("{} * {}", fmt(lhs), fmt(rhs)),
130+
Negate(expr) => format!("-{}", fmt(expr)),
131+
Remainder(lhs, rhs) => format!("{} % {}", fmt(lhs), fmt(rhs)),
132+
SquareRoot(expr) => format!("sqrt({})", fmt(expr)),
133+
Subtract(lhs, rhs) => format!("{} - {}", fmt(lhs), fmt(rhs)),
134+
135+
Like(lhs, rhs) => format!("{} LIKE {}", fmt(lhs), fmt(rhs)),
126136
}
127137
}
128138

129-
/// Formats a constant expression. Errors on column references.
130-
pub fn format_constant(&self) -> String {
131-
self.format(&Node::Nothing { columns: Vec::new() })
132-
}
133-
134-
/// Evaluates an expression, returning a value. Column references look up
135-
/// values in the given row. If None, any Column references will panic.
139+
/// Evaluates an expression, returning a constant value. Column references
140+
/// are looked up in the given row (or panic if the row is None).
136141
pub fn evaluate(&self, row: Option<&Row>) -> Result<Value> {
137142
use Value::*;
143+
138144
Ok(match self {
139145
// Constant values return themselves.
140146
Self::Constant(value) => value.clone(),
141147

142148
// Column references look up a row value. The planner ensures that
143149
// only constant expressions are evaluated without a row.
144-
Self::Column(index) => match row {
145-
Some(row) => row.get(*index).expect("short row").clone(),
146-
None => panic!("can't reference column {index} with constant evaluation"),
147-
},
150+
Self::Column(index) => row.and_then(|r| r.get(*index)).cloned().expect("invalid index"),
148151

149152
// Logical AND. Inputs must be boolean or NULL. NULLs generally
150153
// yield NULL, except the special case NULL AND false == false.
@@ -174,8 +177,8 @@ impl Expression {
174177
// Comparisons. Must be of same type, except floats and integers
175178
// which are interchangeable. NULLs yield NULL, NaNs yield NaN.
176179
//
177-
// Does not dispatch to Value.cmp() because sorting and comparisons
178-
// are different for f64 NaN and -0.0 values.
180+
// Does not dispatch to Value.cmp() because comparison and sorting
181+
// is different for Nulls and NaNs in SQL and code.
179182
#[allow(clippy::float_cmp)]
180183
Self::Equal(lhs, rhs) => match (lhs.evaluate(row)?, rhs.evaluate(row)?) {
181184
(Boolean(lhs), Boolean(rhs)) => Boolean(lhs == rhs),
@@ -222,18 +225,19 @@ impl Expression {
222225

223226
// Mathematical operations. Inputs must be numbers, but integers and
224227
// floats are interchangeable (float when mixed). NULLs yield NULL.
225-
// Errors on integer overflow, while floats yield infinity or NaN.
228+
// Errors on integer overflow, but floats yield infinity or NaN.
226229
Self::Add(lhs, rhs) => lhs.evaluate(row)?.checked_add(&rhs.evaluate(row)?)?,
227230
Self::Divide(lhs, rhs) => lhs.evaluate(row)?.checked_div(&rhs.evaluate(row)?)?,
228231
Self::Exponentiate(lhs, rhs) => lhs.evaluate(row)?.checked_pow(&rhs.evaluate(row)?)?,
229232
Self::Factorial(expr) => match expr.evaluate(row)? {
230-
Integer(i) if i < 0 => return errinput!("can't take factorial of negative number"),
231-
Integer(i) => (1..=i).try_fold(Integer(1), |p, i| p.checked_mul(&Integer(i)))?,
233+
Integer(i @ 0..) => {
234+
(1..=i).try_fold(Integer(1), |p, i| p.checked_mul(&Integer(i)))?
235+
}
232236
Null => Null,
233237
value => return errinput!("can't take factorial of {value}"),
234238
},
235239
Self::Identity(expr) => match expr.evaluate(row)? {
236-
v @ (Integer(_) | Float(_) | Null) => v,
240+
value @ (Integer(_) | Float(_) | Null) => value,
237241
expr => return errinput!("can't take the identity of {expr}"),
238242
},
239243
Self::Multiply(lhs, rhs) => lhs.evaluate(row)?.checked_mul(&rhs.evaluate(row)?)?,
@@ -245,8 +249,7 @@ impl Expression {
245249
},
246250
Self::Remainder(lhs, rhs) => lhs.evaluate(row)?.checked_rem(&rhs.evaluate(row)?)?,
247251
Self::SquareRoot(expr) => match expr.evaluate(row)? {
248-
Integer(i) if i < 0 => return errinput!("can't take negative square root"),
249-
Integer(i) => Float((i as f64).sqrt()),
252+
Integer(i @ 0..) => Float((i as f64).sqrt()),
250253
Float(f) => Float(f.sqrt()),
251254
Null => Null,
252255
value => return errinput!("can't take square root of {value}"),
@@ -259,7 +262,7 @@ impl Expression {
259262
Self::Like(lhs, rhs) => match (lhs.evaluate(row)?, rhs.evaluate(row)?) {
260263
(String(lhs), String(rhs)) => {
261264
// We could precompile the pattern if it's constant, instead
262-
// of recompiling it for every row, but this is fine.
265+
// of recompiling it for every row, but we keep it simple.
263266
let pattern =
264267
format!("^{}$", regex::escape(&rhs).replace('%', ".*").replace('_', "."));
265268
Boolean(Regex::new(&pattern)?.is_match(&lhs))
@@ -350,12 +353,13 @@ impl Expression {
350353
}
351354

352355
/// Converts the expression into conjunctive normal form, i.e. an AND of
353-
/// ORs, which is useful when optimizing plans. This is done by converting
354-
/// to negation normal form and then applying De Morgan's distributive law.
356+
/// ORs, useful during plan optimization. This is done by converting to
357+
/// negation normal form and then applying De Morgan's distributive law.
355358
pub fn into_cnf(self) -> Self {
356-
use Expression::*;
359+
use Expression::{And, Or};
360+
357361
let xform = |expr| {
358-
// We can't use a single match, since it needs deref patterns.
362+
// Can't use a single match; needs deref patterns.
359363
let Or(lhs, rhs) = expr else {
360364
return expr;
361365
};
@@ -371,13 +375,30 @@ impl Expression {
371375
self.into_nnf().transform(&|e| Ok(xform(e)), &Ok).unwrap() // infallible
372376
}
373377

378+
/// Converts the expression into conjunctive normal form as a vector of
379+
/// ANDed expressions (instead of nested ANDs).
380+
pub fn into_cnf_vec(self) -> Vec<Self> {
381+
let mut cnf = Vec::new();
382+
let mut stack = vec![self.into_cnf()];
383+
while let Some(expr) = stack.pop() {
384+
if let Self::And(lhs, rhs) = expr {
385+
stack.extend([*rhs, *lhs]); // push lhs last to pop it first
386+
} else {
387+
cnf.push(expr);
388+
}
389+
}
390+
cnf
391+
}
392+
374393
/// Converts the expression into negation normal form. This pushes NOT
375394
/// operators into the tree using De Morgan's laws, such that they're always
376395
/// below other logical operators. It is a useful intermediate form for
377396
/// applying other logical normalizations.
378397
pub fn into_nnf(self) -> Self {
379-
use Expression::*;
398+
use Expression::{And, Not, Or};
399+
380400
let xform = |expr| {
401+
// Can't use a single match; needs deref patterns.
381402
let Not(inner) = expr else {
382403
return expr;
383404
};
@@ -392,22 +413,7 @@ impl Expression {
392413
expr => Not(expr.into()),
393414
}
394415
};
395-
self.transform(&|e| Ok(xform(e)), &Ok).unwrap() // never fails
396-
}
397-
398-
/// Converts the expression into conjunctive normal form as a vector of
399-
/// ANDed expressions (instead of nested ANDs).
400-
pub fn into_cnf_vec(self) -> Vec<Self> {
401-
let mut cnf = Vec::new();
402-
let mut stack = vec![self.into_cnf()];
403-
while let Some(expr) = stack.pop() {
404-
if let Self::And(lhs, rhs) = expr {
405-
stack.extend([*rhs, *lhs]); // push lhs last to pop it first
406-
} else {
407-
cnf.push(expr);
408-
}
409-
}
410-
cnf
416+
self.transform(&|e| Ok(xform(e)), &Ok).unwrap() // infallible
411417
}
412418

413419
/// Creates an expression by ANDing together a vector, or None if empty.
@@ -424,6 +430,7 @@ impl Expression {
424430
/// = or IS NULL/NAN for a single column), returning the column index.
425431
pub fn is_column_lookup(&self) -> Option<usize> {
426432
use Expression::*;
433+
427434
match &self {
428435
// Column/constant equality can use index lookups. NULL and NaN are
429436
// handled in into_column_values().
@@ -451,6 +458,7 @@ impl Expression {
451458
/// must return true for the expression.
452459
pub fn into_column_values(self, index: usize) -> Vec<Value> {
453460
use Expression::*;
461+
454462
match self {
455463
Equal(lhs, rhs) => match (*lhs, *rhs) {
456464
(Column(column), Constant(value)) | (Constant(value), Column(column)) => {
@@ -478,7 +486,7 @@ impl Expression {
478486
}
479487
}
480488

481-
/// Replaces column references with the given column.
489+
/// Replaces column references from → to.
482490
pub fn replace_column(self, from: usize, to: usize) -> Self {
483491
let xform = |expr| match expr {
484492
Expression::Column(i) if i == from => Expression::Column(to),
@@ -487,7 +495,7 @@ impl Expression {
487495
self.transform(&|e| Ok(xform(e)), &Ok).unwrap() // infallible
488496
}
489497

490-
/// Shifts column references by the given amount.
498+
/// Shifts column references by the given amount (can be negative).
491499
pub fn shift_column(self, diff: isize) -> Self {
492500
let xform = |expr| match expr {
493501
Expression::Column(i) => Expression::Column((i as isize + diff) as usize),

0 commit comments

Comments
 (0)