@@ -16,17 +16,38 @@ function normalizeSQL(sql: string) {
1616
1717async function loadAllowlist ( dataSource : DataSource ) : Promise < string [ ] > {
1818 try {
19- const statement = 'SELECT sql_statement FROM tmp_allowlist_queries'
19+ const statement =
20+ 'SELECT sql_statement, source FROM tmp_allowlist_queries'
2021 const result = ( await dataSource . rpc . executeQuery ( {
2122 sql : statement ,
2223 } ) ) as QueryResult [ ]
23- return result . map ( ( row ) => String ( row . sql_statement ) )
24+ return result
25+ . filter ( ( row ) => row . source === dataSource . source )
26+ . map ( ( row ) => String ( row . sql_statement ) )
2427 } catch ( error ) {
2528 console . error ( 'Error loading allowlist:' , error )
2629 return [ ]
2730 }
2831}
2932
33+ async function addRejectedQuery (
34+ query : string ,
35+ dataSource : DataSource
36+ ) : Promise < string [ ] > {
37+ try {
38+ const statement =
39+ 'INSERT INTO tmp_allowlist_rejections (sql_statement, source) VALUES (?, ?)'
40+ const result = ( await dataSource . rpc . executeQuery ( {
41+ sql : statement ,
42+ params : [ query , dataSource . source ] ,
43+ } ) ) as QueryResult [ ]
44+ return result . map ( ( row ) => String ( row . sql_statement ) )
45+ } catch ( error ) {
46+ console . error ( 'Error inserting rejected allowlist query:' , error )
47+ return [ ]
48+ }
49+ }
50+
3051export async function isQueryAllowed ( opts : {
3152 sql : string
3253 isEnabled : boolean
@@ -59,34 +80,78 @@ export async function isQueryAllowed(opts: {
5980 const normalizedQuery = parser . astify ( normalizeSQL ( sql ) )
6081
6182 // Compare ASTs while ignoring specific values
62- const isCurrentAllowed = normalizedAllowlist ?. some ( ( allowedQuery ) => {
63- // Create deep copies to avoid modifying original ASTs
64- const allowedAst = JSON . parse ( JSON . stringify ( allowedQuery ) )
65- const queryAst = JSON . parse ( JSON . stringify ( normalizedQuery ) )
66-
67- // Remove or normalize value fields from both ASTs
68- const normalizeAst = ( ast : any ) => {
69- if ( Array . isArray ( ast ) ) {
70- ast . forEach ( normalizeAst )
71- } else if ( ast && typeof ast === 'object' ) {
72- // Remove or normalize fields that contain specific values
73- if ( 'value' in ast ) {
74- ast . value = '?'
75- }
76-
77- Object . values ( ast ) . forEach ( normalizeAst )
78- }
79-
80- return ast
83+ // const isCurrentAllowed = normalizedAllowlist?.some((allowedQuery) => {
84+ // // Create deep copies to avoid modifying original ASTs
85+ // const allowedAst = JSON.parse(JSON.stringify(allowedQuery))
86+ // const queryAst = JSON.parse(JSON.stringify(normalizedQuery))
87+
88+ // // Remove or normalize value fields from both ASTs
89+ // const normalizeAst = (ast: any) => {
90+ // if (Array.isArray(ast)) {
91+ // ast.forEach(normalizeAst)
92+ // } else if (ast && typeof ast === 'object') {
93+ // // Remove or normalize fields that contain specific values
94+ // if ('value' in ast) {
95+ // // Preserve the value for specific clauses like LIMIT
96+ // if (ast.as === 'limit' || ast.type === 'limit') {
97+ // // Do not normalize LIMIT values
98+ // return;
99+ // }
100+ // ast.value = '?'; // Normalize other values
101+ // }
102+
103+ // // Recursively normalize all other fields
104+ // Object.values(ast).forEach(normalizeAst)
105+ // }
106+
107+ // return ast;
108+ // };
109+
110+ // normalizeAst(allowedAst)
111+ // normalizeAst(queryAst)
112+
113+ // return JSON.stringify(allowedAst) === JSON.stringify(queryAst)
114+ // })
115+
116+ const deepCompareAst = ( allowedAst : any , queryAst : any ) : boolean => {
117+ if ( typeof allowedAst !== typeof queryAst ) return false
118+
119+ if ( Array . isArray ( allowedAst ) && Array . isArray ( queryAst ) ) {
120+ if ( allowedAst . length !== queryAst . length ) return false
121+ return allowedAst . every ( ( item , index ) =>
122+ deepCompareAst ( item , queryAst [ index ] )
123+ )
124+ } else if (
125+ typeof allowedAst === 'object' &&
126+ allowedAst !== null &&
127+ queryAst !== null
128+ ) {
129+ const allowedKeys = Object . keys ( allowedAst )
130+ const queryKeys = Object . keys ( queryAst )
131+
132+ if ( allowedKeys . length !== queryKeys . length ) return false
133+
134+ return allowedKeys . every ( ( key ) =>
135+ deepCompareAst ( allowedAst [ key ] , queryAst [ key ] )
136+ )
81137 }
82138
83- normalizeAst ( allowedAst )
84- normalizeAst ( queryAst )
139+ // Base case: Primitive value comparison
140+ return allowedAst === queryAst
141+ }
85142
86- return JSON . stringify ( allowedAst ) === JSON . stringify ( queryAst )
87- } )
143+ const isCurrentAllowed = normalizedAllowlist ?. some ( ( allowedQuery ) =>
144+ deepCompareAst ( allowedQuery , normalizedQuery )
145+ )
88146
89147 if ( ! isCurrentAllowed ) {
148+ // For any rejected query, we can add it to a table of rejected queries
149+ // to act both as an audit log as well as an easy way to see recent queries
150+ // that may need to be added to the allowlist in an easy way via a user
151+ // interface.
152+ addRejectedQuery ( sql , dataSource )
153+
154+ // Then throw the appropriate error to the user.
90155 throw new Error ( 'Query not allowed' )
91156 }
92157
0 commit comments