Skip to content

Commit 1053072

Browse files
authored
Add support for validating SQL queries with the WITH clause and extend unit tests (#1327)
1 parent a81307f commit 1053072

File tree

2 files changed

+122
-56
lines changed
  • dataframe-jdbc/src

2 files changed

+122
-56
lines changed

dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,11 @@ private fun hasForbiddenPatterns(input: String): Boolean {
557557
return false
558558
}
559559

560+
/**
561+
* SQL keywords that a query is allowed to start with in order to pass validation.
562+
*/
563+
private val ALLOWED_SQL_OPERATORS = listOf("SELECT", "WITH", "VALUES", "TABLE")
564+
560565
/**
561566
* Validates if the SQL query is safe and starts with one of the allowed SQL operators (SELECT, WITH, VALUES or TABLE).
562567
* Ensures a proper syntax structure, checks for balanced quotes, and disallows dangerous commands or patterns.
@@ -567,9 +572,11 @@ private fun isValidSqlQuery(sqlQuery: String): Boolean {
567572
// Log the query being validated
568573
logger.warn { "Validating SQL query: '$sqlQuery'" }
569574

570-
// Ensure the query starts with "SELECT"
571-
if (!normalizedSqlQuery.startsWith("SELECT")) {
572-
logger.error { "Validation failed: The SQL query must start with 'SELECT'. Given query: '$sqlQuery'." }
575+
// Ensure the query starts with one of the allowed SQL operators
576+
if (ALLOWED_SQL_OPERATORS.none { normalizedSqlQuery.startsWith(it) }) {
577+
logger.error {
578+
"Validation failed: The SQL query must start with one of: $ALLOWED_SQL_OPERATORS. Given query: '$sqlQuery'."
579+
}
573580
return false
574581
}
575582

dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/postgresH2Test.kt

Lines changed: 112 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -78,69 +78,75 @@ class PostgresH2Test {
7878
connection = DriverManager.getConnection(URL)
7979

8080
@Language("SQL")
81-
val createTableStatement = """
81+
val createTableStatement =
82+
"""
8283
CREATE TABLE IF NOT EXISTS table1 (
83-
id serial PRIMARY KEY,
84-
bigintCol bigint not null,
85-
smallintCol smallint not null,
86-
bigserialCol bigserial not null,
87-
booleanCol boolean not null,
88-
byteaCol bytea not null,
89-
characterCol character not null,
90-
characterNCol character(10) not null,
91-
charCol char not null,
92-
dateCol date not null,
93-
doubleCol double precision not null,
94-
integerCol integer,
95-
intArrayCol integer array,
96-
doubleArrayCol double precision array,
97-
dateArrayCol date array,
98-
textArrayCol text array,
99-
booleanArrayCol boolean array
100-
)
101-
"""
84+
id serial PRIMARY KEY,
85+
bigintCol bigint not null,
86+
smallintCol smallint not null,
87+
bigserialCol bigserial not null,
88+
booleanCol boolean not null,
89+
byteaCol bytea not null,
90+
characterCol character not null,
91+
characterNCol character(10) not null,
92+
charCol char not null,
93+
dateCol date not null,
94+
doubleCol double precision not null,
95+
integerCol integer,
96+
intArrayCol integer array,
97+
doubleArrayCol double precision array,
98+
dateArrayCol date array,
99+
textArrayCol text array,
100+
booleanArrayCol boolean array
101+
)
102+
""".trimIndent()
103+
102104
connection.createStatement().execute(createTableStatement.trimIndent())
103105

104106
@Language("SQL")
105-
val createTableQuery = """
107+
val createTableQuery =
108+
"""
106109
CREATE TABLE IF NOT EXISTS table2 (
107-
id serial PRIMARY KEY,
108-
moneyCol money not null,
109-
numericCol numeric not null,
110-
realCol real not null,
111-
smallintCol smallint not null,
112-
serialCol serial not null,
113-
textCol text,
114-
timeCol time not null,
115-
timeWithZoneCol time with time zone not null,
116-
timestampCol timestamp not null,
117-
timestampWithZoneCol timestamp with time zone not null,
118-
uuidCol uuid not null
119-
)
120-
"""
110+
id serial PRIMARY KEY,
111+
moneyCol money not null,
112+
numericCol numeric not null,
113+
realCol real not null,
114+
smallintCol smallint not null,
115+
serialCol serial not null,
116+
textCol text,
117+
timeCol time not null,
118+
timeWithZoneCol time with time zone not null,
119+
timestampCol timestamp not null,
120+
timestampWithZoneCol timestamp with time zone not null,
121+
uuidCol uuid not null
122+
)
123+
""".trimIndent()
124+
121125
connection.createStatement().execute(createTableQuery.trimIndent())
122126

123127
@Language("SQL")
124-
val insertData1 = """
125-
INSERT INTO table1 (
126-
bigintCol, smallintCol, bigserialCol, booleanCol,
127-
byteaCol, characterCol, characterNCol, charCol,
128-
dateCol, doubleCol,
129-
integerCol, intArrayCol,
130-
doubleArrayCol, dateArrayCol, textArrayCol, booleanArrayCol
131-
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
132-
"""
128+
val insertData1 =
129+
"""
130+
INSERT INTO table1 (
131+
bigintCol, smallintCol, bigserialCol, booleanCol,
132+
byteaCol, characterCol, characterNCol, charCol,
133+
dateCol, doubleCol,
134+
integerCol, intArrayCol,
135+
doubleArrayCol, dateArrayCol, textArrayCol, booleanArrayCol
136+
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
137+
""".trimIndent()
133138

134139
@Language("SQL")
135-
val insertData2 = """
136-
INSERT INTO table2 (
137-
moneyCol, numericCol,
138-
realCol, smallintCol,
139-
serialCol, textCol, timeCol,
140-
timeWithZoneCol, timestampCol, timestampWithZoneCol,
141-
uuidCol
142-
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
143-
"""
140+
val insertData2 =
141+
"""
142+
INSERT INTO table2 (
143+
moneyCol, numericCol,
144+
realCol, smallintCol,
145+
serialCol, textCol, timeCol,
146+
timeWithZoneCol, timestampCol, timestampWithZoneCol,
147+
uuidCol
148+
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
149+
""".trimIndent()
144150

145151
val intArray = connection.createArrayOf("INTEGER", arrayOf(1, 2, 3))
146152
val doubleArray = connection.createArrayOf("DOUBLE", arrayOf(1.1, 2.2, 3.3))
@@ -327,4 +333,57 @@ class PostgresH2Test {
327333
fun `infer nullability`() {
328334
inferNullability(connection)
329335
}
336+
337+
@Test
338+
fun `readSqlQuery should execute a WITH clause and return results`() {
339+
try {
340+
// Step 1: Create a temporary table
341+
@Language("SQL")
342+
val createTableQuery =
343+
"""
344+
CREATE TABLE employees (
345+
id INT PRIMARY KEY,
346+
name VARCHAR(100),
347+
salary DOUBLE
348+
)
349+
""".trimIndent()
350+
connection.createStatement().execute(createTableQuery)
351+
352+
// Step 2: Insert data into the table
353+
@Language("SQL")
354+
val insertDataQuery =
355+
"""
356+
INSERT INTO employees (id, name, salary) VALUES
357+
(1, 'Alice', 60000.0),
358+
(2, 'Bob', 50000.0),
359+
(3, 'Charlie', 70000.0)
360+
""".trimIndent()
361+
362+
connection.createStatement().execute(insertDataQuery)
363+
364+
// Step 3: Execute the query with a WITH clause
365+
@Language("SQL")
366+
val queryWithClause =
367+
"""
368+
WITH high_earners AS (
369+
SELECT name, salary
370+
FROM employees
371+
WHERE salary > 55000.0
372+
)
373+
SELECT * FROM high_earners
374+
""".trimIndent()
375+
376+
val resultDataFrame = DataFrame.readSqlQuery(connection, queryWithClause)
377+
378+
// Step 4: Validate the results
379+
resultDataFrame.rowsCount() shouldBe 2
380+
resultDataFrame[0][0] shouldBe "Alice"
381+
resultDataFrame[1][0] shouldBe "Charlie"
382+
} finally {
383+
// Step 5: Clean up the temporary table
384+
@Language("SQL")
385+
val dropTableQuery = "DROP TABLE IF EXISTS employees"
386+
connection.createStatement().execute(dropTableQuery)
387+
}
388+
}
330389
}

0 commit comments

Comments
 (0)