@@ -16,17 +16,38 @@ function normalizeSQL(sql: string) {
1616
1717async function loadAllowlist ( dataSource : DataSource ) : Promise < string [ ] > {
1818 try {
19- const statement = 'SELECT sql_statement FROM tmp_allowlist_queries'
19+ const statement =
20+ 'SELECT sql_statement, source FROM tmp_allowlist_queries'
2021 const result = ( await dataSource . rpc . executeQuery ( {
2122 sql : statement ,
2223 } ) ) as QueryResult [ ]
23- return result . map ( ( row ) => String ( row . sql_statement ) )
24+ return result
25+ . filter ( ( row ) => row . source === dataSource . source )
26+ . map ( ( row ) => String ( row . sql_statement ) )
2427 } catch ( error ) {
2528 console . error ( 'Error loading allowlist:' , error )
2629 return [ ]
2730 }
2831}
2932
33+ async function addRejectedQuery (
34+ query : string ,
35+ dataSource : DataSource
36+ ) : Promise < string [ ] > {
37+ try {
38+ const statement =
39+ 'INSERT INTO tmp_allowlist_rejections (sql_statement, source) VALUES (?, ?)'
40+ const result = ( await dataSource . rpc . executeQuery ( {
41+ sql : statement ,
42+ params : [ query , dataSource . source ] ,
43+ } ) ) as QueryResult [ ]
44+ return result . map ( ( row ) => String ( row . sql_statement ) )
45+ } catch ( error ) {
46+ console . error ( 'Error inserting rejected allowlist query:' , error )
47+ return [ ]
48+ }
49+ }
50+
3051export async function isQueryAllowed ( opts : {
3152 sql : string
3253 isEnabled : boolean
@@ -59,34 +80,45 @@ export async function isQueryAllowed(opts: {
5980 const normalizedQuery = parser . astify ( normalizeSQL ( sql ) )
6081
6182 // Compare ASTs while ignoring specific values
62- const isCurrentAllowed = normalizedAllowlist ?. some ( ( allowedQuery ) => {
63- // Create deep copies to avoid modifying original ASTs
64- const allowedAst = JSON . parse ( JSON . stringify ( allowedQuery ) )
65- const queryAst = JSON . parse ( JSON . stringify ( normalizedQuery ) )
66-
67- // Remove or normalize value fields from both ASTs
68- const normalizeAst = ( ast : any ) => {
69- if ( Array . isArray ( ast ) ) {
70- ast . forEach ( normalizeAst )
71- } else if ( ast && typeof ast === 'object' ) {
72- // Remove or normalize fields that contain specific values
73- if ( 'value' in ast ) {
74- ast . value = '?'
75- }
76-
77- Object . values ( ast ) . forEach ( normalizeAst )
78- }
79-
80- return ast
83+ const deepCompareAst = ( allowedAst : any , queryAst : any ) : boolean => {
84+ if ( typeof allowedAst !== typeof queryAst ) return false
85+
86+ if ( Array . isArray ( allowedAst ) && Array . isArray ( queryAst ) ) {
87+ if ( allowedAst . length !== queryAst . length ) return false
88+ return allowedAst . every ( ( item , index ) =>
89+ deepCompareAst ( item , queryAst [ index ] )
90+ )
91+ } else if (
92+ typeof allowedAst === 'object' &&
93+ allowedAst !== null &&
94+ queryAst !== null
95+ ) {
96+ const allowedKeys = Object . keys ( allowedAst )
97+ const queryKeys = Object . keys ( queryAst )
98+
99+ if ( allowedKeys . length !== queryKeys . length ) return false
100+
101+ return allowedKeys . every ( ( key ) =>
102+ deepCompareAst ( allowedAst [ key ] , queryAst [ key ] )
103+ )
81104 }
82105
83- normalizeAst ( allowedAst )
84- normalizeAst ( queryAst )
106+ // Base case: Primitive value comparison
107+ return allowedAst === queryAst
108+ }
85109
86- return JSON . stringify ( allowedAst ) === JSON . stringify ( queryAst )
87- } )
110+ const isCurrentAllowed = normalizedAllowlist ?. some ( ( allowedQuery ) =>
111+ deepCompareAst ( allowedQuery , normalizedQuery )
112+ )
88113
89114 if ( ! isCurrentAllowed ) {
115+ // For any rejected query, we can add it to a table of rejected queries
116+ // to act both as an audit log as well as an easy way to see recent queries
117+ // that may need to be added to the allowlist in an easy way via a user
118+ // interface.
119+ addRejectedQuery ( sql , dataSource )
120+
121+ // Then throw the appropriate error to the user.
90122 throw new Error ( 'Query not allowed' )
91123 }
92124
0 commit comments