Skip to content

Commit 5df824a

Browse files
committed
store: Extract SQL parsing tests into a YAML file
This setup makes it much easier to add more tests that check that we scrub dangerous constructs from SQL.
1 parent 37dafa8 commit 5df824a

File tree

2 files changed

+153
-62
lines changed

2 files changed

+153
-62
lines changed

store/postgres/src/sql/parser.rs

Lines changed: 98 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,32 @@ impl Parser {
3434

3535
#[cfg(test)]
3636
mod test {
37+
use std::sync::Arc;
3738

38-
use graph::prelude::BLOCK_NUMBER_MAX;
39+
use crate::sql::{parser::SQL_DIALECT, test::make_layout};
40+
use graph::prelude::{lazy_static, serde_yaml, BLOCK_NUMBER_MAX};
41+
use serde::{Deserialize, Serialize};
3942

40-
use crate::sql::test::make_layout;
43+
use pretty_assertions::assert_eq;
4144

42-
use super::*;
45+
use super::Parser;
4346

4447
const TEST_GQL: &str = "
45-
type SwapMulti @entity(immutable: true) {
48+
type Swap @entity(immutable: true) {
4649
id: Bytes!
47-
sender: Bytes! # address
48-
amountsIn: [BigInt!]! # uint256[]
49-
tokensIn: [Bytes!]! # address[]
50-
amountsOut: [BigInt!]! # uint256[]
51-
tokensOut: [Bytes!]! # address[]
52-
referralCode: BigInt! # uint32
53-
blockNumber: BigInt!
54-
blockTimestamp: BigInt!
55-
transactionHash: Bytes!
50+
timestamp: BigInt!
51+
pool: Bytes!
52+
token0: Bytes!
53+
token1: Bytes!
54+
sender: Bytes!
55+
recipient: Bytes!
56+
origin: Bytes! # the EOA that initiated the txn
57+
amount0: BigDecimal!
58+
amount1: BigDecimal!
59+
amountUSD: BigDecimal!
60+
sqrtPriceX96: BigInt!
61+
tick: BigInt!
62+
logIndex: BigInt
5663
}
5764
5865
type Token @entity {
@@ -64,64 +71,93 @@ mod test {
6471
}
6572
";
6673

67-
const SQL_QUERY: &str = "
68-
with tokens as (
69-
select * from (values
70-
('0x0000000000000000000000000000000000000000','ETH','Ethereum',18),
71-
('0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48','USDC','USD Coin',6)
72-
) as t(address,symbol,name,decimals)
73-
)
74-
75-
select
76-
date,
77-
t.symbol,
78-
SUM(amount)/pow(10,t.decimals) as amount
79-
from (select
80-
date(to_timestamp(block_timestamp) at time zone 'utc') as date,
81-
token,
82-
amount
83-
from swap_multi as sm
84-
,unnest(sm.amounts_in,sm.tokens_in) as smi(amount,token)
85-
union all
86-
select
87-
date(to_timestamp(block_timestamp) at time zone 'utc') as date,
88-
token,
89-
amount
90-
from sgd1.swap_multi as sm
91-
,unnest(sm.amounts_out,sm.tokens_out) as smo(amount,token)
92-
) as tp
93-
inner join tokens as t on t.address = '0x' || encode(tp.token,'hex')
94-
group by tp.date,t.symbol,t.decimals
95-
order by tp.date desc ,amount desc
96-
97-
";
98-
9974
/// Parses and validates `sql` with a `Parser` built from the `TEST_GQL`
/// schema layout and `BLOCK_NUMBER_MAX`, returning whatever the parser
/// returns: the resulting query string or the error.
fn parse_and_validate(sql: &str) -> Result<String, anyhow::Error> {
    let layout = Arc::new(make_layout(TEST_GQL));
    Parser::new(layout, BLOCK_NUMBER_MAX).parse_and_validate(sql)
}
10479

105-
#[test]
106-
fn parse_sql() {
107-
let query = parse_and_validate(SQL_QUERY).unwrap();
80+
#[derive(Debug, Serialize, Deserialize)]
81+
struct TestCase {
82+
name: Option<String>,
83+
sql: String,
84+
ok: Option<String>,
85+
err: Option<String>,
86+
}
10887

109-
assert_eq!(
110-
query,
111-
r#"WITH "swap_multi" AS (SELECT concat('0x', encode("id", 'hex')) AS "id", concat('0x', encode("sender", 'hex')) AS "sender", "amounts_in", "tokens_in", "amounts_out", "tokens_out", "referral_code", "block_number", "block_timestamp", concat('0x', encode("transaction_hash", 'hex')) AS "transaction_hash", "block$" FROM "sgd0815"."swap_multi"),
112-
"token" AS (SELECT "id", concat('0x', encode("address", 'hex')) AS "address", "symbol", "name", "decimals", "block_range" FROM "sgd0815"."token" WHERE "block_range" @> 2147483647) SELECT to_jsonb(sub.*) AS data FROM ( WITH tokens AS (SELECT * FROM (VALUES ('0x0000000000000000000000000000000000000000', 'ETH', 'Ethereum', 18), ('0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48', 'USDC', 'USD Coin', 6)) AS t (address, symbol, name, decimals)) SELECT date, t.symbol, SUM(amount) / pow(10, t.decimals) AS amount FROM (SELECT date(to_timestamp(block_timestamp) AT TIME ZONE 'utc') AS date, token, amount FROM "swap_multi" AS sm, UNNEST(sm.amounts_in, sm.tokens_in) AS smi (amount, token) UNION ALL SELECT date(to_timestamp(block_timestamp) AT TIME ZONE 'utc') AS date, token, amount FROM "swap_multi" AS sm, UNNEST(sm.amounts_out, sm.tokens_out) AS smo (amount, token)) AS tp JOIN tokens AS t ON t.address = '0x' || encode(tp.token, 'hex') GROUP BY tp.date, t.symbol, t.decimals ORDER BY tp.date DESC, amount DESC ) AS sub"#
113-
);
88+
impl TestCase {
89+
fn fail(
90+
&self,
91+
name: &str,
92+
msg: &str,
93+
exp: impl std::fmt::Display,
94+
actual: impl std::fmt::Display,
95+
) {
96+
panic!(
97+
"case {name} failed: {}\n expected: {}\n actual: {}",
98+
msg, exp, actual
99+
);
100+
}
101+
102+
fn run(&self, num: usize) {
103+
fn normalize(query: &str) -> String {
104+
sqlparser::parser::Parser::parse_sql(&SQL_DIALECT, query)
105+
.unwrap()
106+
.pop()
107+
.unwrap()
108+
.to_string()
109+
}
110+
111+
let name = self
112+
.name
113+
.as_ref()
114+
.map(|name| format!("{num} ({name})"))
115+
.unwrap_or_else(|| num.to_string());
116+
let result = parse_and_validate(&self.sql);
117+
118+
match (&self.ok, &self.err, result) {
119+
(Some(expected), None, Ok(actual)) => {
120+
let actual = normalize(&actual);
121+
let expected = normalize(expected);
122+
assert_eq!(actual, expected, "case {} failed", name);
123+
}
124+
(None, Some(expected), Err(actual)) => {
125+
let actual = actual.to_string();
126+
if !actual.contains(expected) {
127+
self.fail(&name, "expected error message not found", expected, actual);
128+
}
129+
}
130+
(Some(_), Some(_), _) => {
131+
panic!("case {} has both ok and err", name);
132+
}
133+
(None, None, _) => {
134+
panic!("case {} has neither ok nor err", name)
135+
}
136+
(None, Some(exp), Ok(actual)) => {
137+
self.fail(&name, "expected an error", exp, actual);
138+
}
139+
(Some(exp), None, Err(actual)) => self.fail(&name, "expected success", exp, actual),
140+
}
141+
}
114142
}
115143

116-
#[test]
117-
fn parse_simple_sql() {
118-
let query =
119-
parse_and_validate("select symbol, address from token where decimals > 10").unwrap();
144+
lazy_static! {
    /// The test cases from `src/sql/parser_tests.yaml` (relative to the
    /// crate's manifest directory), read and deserialized on first access.
    ///
    /// Panics with the file path and the underlying error if the file
    /// cannot be read or is not a valid list of `TestCase`s.
    static ref TESTS: Vec<TestCase> = {
        let file = std::path::PathBuf::from_iter([
            env!("CARGO_MANIFEST_DIR"),
            "src",
            "sql",
            "parser_tests.yaml",
        ]);
        let tests = std::fs::read_to_string(&file)
            .unwrap_or_else(|e| panic!("failed to read {}: {e}", file.display()));
        serde_yaml::from_str(&tests)
            .unwrap_or_else(|e| panic!("failed to parse {}: {e}", file.display()))
    };
}
120156

121-
assert_eq!(
122-
query,
123-
r#"select to_jsonb(sub.*) as data from ( SELECT symbol, address FROM (SELECT * FROM "sgd0815"."token" WHERE block_range @> 2147483647) AS token WHERE decimals > 10 ) as sub"#
124-
);
125-
println!("{}", query);
157+
/// Runs every case in `TESTS`, passing each one its index in the list so
/// that failures can be traced back to the offending entry.
#[test]
fn parse_sql() {
    TESTS
        .iter()
        .enumerate()
        .for_each(|(num, case)| case.run(num));
}
127163
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Test cases for the SQL parser. Each test case has the following fields:
2+
# name : an optional name for error messages
3+
# sql : the SQL query to parse
4+
# ok : the expected rewritten query
5+
# err : a part of the error message if parsing should fail
6+
# Exactly one of ok and err must be specified
7+
8+
- sql: select symbol, address from token where decimals > 10
9+
ok: >
10+
select to_jsonb(sub.*) as data from (
11+
SELECT symbol, address FROM (
12+
SELECT * FROM "sgd0815"."token" WHERE block_range @> 2147483647) AS token
13+
WHERE decimals > 10 ) as sub
14+
- sql: >
15+
with tokens as (
16+
select * from (values
17+
('0x0000000000000000000000000000000000000000','ETH','Ethereum',18),
18+
('0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48','USDC','USD Coin',6)
19+
) as t(address,symbol,name,decimals)
20+
)
21+
22+
select date, t.symbol, SUM(amount)/pow(10,t.decimals) as amount
23+
from (select
24+
date(to_timestamp(block_timestamp) at time zone 'utc') as date,
25+
token, amount
26+
from swap as sm,
27+
unnest(sm.amounts_in,sm.tokens_in) as smi(amount,token)
28+
union all
29+
select
30+
date(to_timestamp(block_timestamp) at time zone 'utc') as date,
31+
token, amount
32+
from swap as sm,
33+
unnest(sm.amounts_out,sm.tokens_out) as smo(amount,token)) as tp
34+
inner join
35+
tokens as t on t.address = tp.token
36+
group by tp.date, t.symbol, t.decimals
37+
order by tp.date desc, amount desc
38+
ok: >
39+
select to_jsonb(sub.*) as data from (
40+
WITH tokens AS (
41+
SELECT * FROM (
42+
VALUES ('0x0000000000000000000000000000000000000000', 'ETH', 'Ethereum', 18),
43+
('0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48', 'USDC', 'USD Coin', 6))
44+
AS t (address, symbol, name, decimals))
45+
SELECT date, t.symbol, SUM(amount) / pow(10, t.decimals) AS amount
46+
FROM (SELECT date(to_timestamp(block_timestamp) AT TIME ZONE 'utc') AS date, token, amount
47+
FROM (SELECT * FROM "sgd0815"."swap" WHERE block$ <= 2147483647) AS sm,
48+
UNNEST(sm.amounts_in, sm.tokens_in) AS smi (amount, token)
49+
UNION ALL
50+
SELECT date(to_timestamp(block_timestamp) AT TIME ZONE 'utc') AS date, token, amount
51+
FROM (SELECT * FROM "sgd0815"."swap" WHERE block$ <= 2147483647) AS sm,
52+
UNNEST(sm.amounts_out, sm.tokens_out) AS smo (amount, token)) AS tp
53+
JOIN tokens AS t ON t.address = tp.token
54+
GROUP BY tp.date, t.symbol, t.decimals
55+
ORDER BY tp.date DESC, amount DESC ) as sub

0 commit comments

Comments
 (0)