Skip to content

Commit 958f1df

Browse files
authored
Add extension functions for the ResultSet (#772)
* Add extension functions for the ResultSet. * Add extension functions for Connection and DatabaseConfiguration. * Refactor database configuration and dataframe methods: rename `DatabaseConfiguration` to `DbConnectionConfig` for clarity and replace `.toDF` with `.readDataFrame` methods to improve method naming consistency. These changes enhance code readability and maintainability. * Refactor SQL reading and schema functions in readJdbc.kt: simplify the logic to use single-expression functions for readability, ensure consistent formatting, and make error messages more explicit. This change also corrects minor indentation issues in SQL query strings within tests. * Rename DatabaseConfiguration to DbConnectionConfig for consistency: this commit updates various imports and references from DatabaseConfiguration to DbConnectionConfig across different files. This change ensures consistency in the naming convention used throughout the codebase and documentation, improving clarity and maintenance.
1 parent 73ba813 commit 958f1df

File tree

5 files changed

+385
-67
lines changed

5 files changed

+385
-67
lines changed

dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt

Lines changed: 194 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ public data class TableMetadata(val name: String, val schemaName: String?, val c
9898
* @property [user] the username used for authentication (optional, default is empty string).
9999
* @property [password] the password used for authentication (optional, default is empty string).
100100
*/
101-
public data class DatabaseConfiguration(val url: String, val user: String = "", val password: String = "")
101+
public data class DbConnectionConfig(
    // Database URL; handed to DriverManager.getConnection together with user and password.
    val url: String,
    // Username used for authentication; defaults to an empty string.
    val user: String = "",
    // Password used for authentication; defaults to an empty string.
    val password: String = "",
)
102102

103103
/**
104104
* Reads data from an SQL table and converts it into a DataFrame.
@@ -110,7 +110,7 @@ public data class DatabaseConfiguration(val url: String, val user: String = "",
110110
* @return the DataFrame containing the data from the SQL table.
111111
*/
112112
public fun DataFrame.Companion.readSqlTable(
113-
dbConfig: DatabaseConfiguration,
113+
dbConfig: DbConnectionConfig,
114114
tableName: String,
115115
limit: Int = DEFAULT_LIMIT,
116116
inferNullability: Boolean = true,
@@ -169,7 +169,7 @@ public fun DataFrame.Companion.readSqlTable(
169169
* @return the DataFrame containing the result of the SQL query.
170170
*/
171171
public fun DataFrame.Companion.readSqlQuery(
172-
dbConfig: DatabaseConfiguration,
172+
dbConfig: DbConnectionConfig,
173173
sqlQuery: String,
174174
limit: Int = DEFAULT_LIMIT,
175175
inferNullability: Boolean = true,
@@ -218,6 +218,89 @@ public fun DataFrame.Companion.readSqlQuery(
218218
}
219219
}
220220

221+
/**
222+
* Reads the result of an SQL query or the contents of an SQL table (by name) into a DataFrame.
223+
*
224+
* @param [sqlQueryOrTableName] the SQL query to execute or name of the SQL table.
225+
* It should be a name of one of the existing SQL tables,
226+
* or the SQL query should start with SELECT and contain a single query that only reads data, without any manipulation.
227+
* It should not contain the `;` symbol.
228+
* @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution.
229+
* @param [inferNullability] indicates how the column nullability should be inferred.
230+
* @return the DataFrame containing the result of the SQL query or the contents of the SQL table.
231+
*/
232+
public fun DbConnectionConfig.readDataFrame(
233+
sqlQueryOrTableName: String,
234+
limit: Int = DEFAULT_LIMIT,
235+
inferNullability: Boolean = true,
236+
): AnyFrame =
237+
when {
238+
isSqlQuery(sqlQueryOrTableName) -> DataFrame.readSqlQuery(
239+
this,
240+
sqlQueryOrTableName,
241+
limit,
242+
inferNullability,
243+
)
244+
245+
isSqlTableName(sqlQueryOrTableName) -> DataFrame.readSqlTable(
246+
this,
247+
sqlQueryOrTableName,
248+
limit,
249+
inferNullability,
250+
)
251+
252+
else -> throw IllegalArgumentException(
253+
"$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!",
254+
)
255+
}
256+
257+
private fun isSqlQuery(sqlQueryOrTableName: String): Boolean {
258+
val queryPattern = Regex("(?i)\\b(SELECT)\\b")
259+
return queryPattern.containsMatchIn(sqlQueryOrTableName.trim())
260+
}
261+
262+
private fun isSqlTableName(sqlQueryOrTableName: String): Boolean {
263+
// Match table names with optional schema and catalog (e.g., catalog.schema.table)
264+
val tableNamePattern = Regex("^[a-zA-Z_][a-zA-Z0-9_]*(\\.[a-zA-Z_][a-zA-Z0-9_]*){0,2}$")
265+
return tableNamePattern.matches(sqlQueryOrTableName.trim())
266+
}
267+
268+
/**
269+
* Reads the result of an SQL query or the contents of an SQL table (by name) into a DataFrame.
270+
*
271+
* @param [sqlQueryOrTableName] the SQL query to execute or name of the SQL table.
272+
* It should be a name of one of the existing SQL tables,
273+
* or the SQL query should start with SELECT and contain a single query that only reads data, without any manipulation.
274+
* It should not contain the `;` symbol.
275+
* @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution.
276+
* @param [inferNullability] indicates how the column nullability should be inferred.
277+
* @return the DataFrame containing the result of the SQL query or the contents of the SQL table.
278+
*/
279+
public fun Connection.readDataFrame(
280+
sqlQueryOrTableName: String,
281+
limit: Int = DEFAULT_LIMIT,
282+
inferNullability: Boolean = true,
283+
): AnyFrame =
284+
when {
285+
isSqlQuery(sqlQueryOrTableName) -> DataFrame.readSqlQuery(
286+
this,
287+
sqlQueryOrTableName,
288+
limit,
289+
inferNullability,
290+
)
291+
292+
isSqlTableName(sqlQueryOrTableName) -> DataFrame.readSqlTable(
293+
this,
294+
sqlQueryOrTableName,
295+
limit,
296+
inferNullability,
297+
)
298+
299+
else -> throw IllegalArgumentException(
300+
"$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!",
301+
)
302+
}
303+
221304
/** An SQL query is accepted only if it starts with SELECT. */
222305
private fun isValid(sqlQuery: String): Boolean {
223306
val normalizedSqlQuery = sqlQuery.trim().uppercase()
@@ -256,6 +339,30 @@ public fun DataFrame.Companion.readResultSet(
256339
return fetchAndConvertDataFromResultSet(tableColumns, resultSet, dbType, limit, inferNullability)
257340
}
258341

342+
/**
343+
* Reads the data from a [ResultSet][java.sql.ResultSet] and converts it into a DataFrame.
344+
*
345+
* A [ResultSet][java.sql.ResultSet] object maintains a cursor pointing to its current row of data.
346+
* By default, a ResultSet object is not updatable and has a cursor that can only move forward.
347+
* Therefore, you can iterate through it only once, from the first row to the last row.
348+
*
349+
* For more details, refer to the official Java documentation on [ResultSet][java.sql.ResultSet].
350+
*
351+
* NOTE: Reading from the [ResultSet][java.sql.ResultSet] could potentially change its state.
352+
*
353+
* @param [dbType] the type of database that the [ResultSet] belongs to.
354+
* @param [limit] the maximum number of rows to read from the [ResultSet][java.sql.ResultSet].
355+
* @param [inferNullability] indicates how the column nullability should be inferred.
356+
* @return the DataFrame generated from the [ResultSet][java.sql.ResultSet] data.
357+
*
358+
* [java.sql.ResultSet]: https://docs.oracle.com/javase/8/docs/api/java/sql/ResultSet.html
359+
*/
360+
public fun ResultSet.readDataFrame(
361+
dbType: DbType,
362+
limit: Int = DEFAULT_LIMIT,
363+
inferNullability: Boolean = true,
364+
): AnyFrame = DataFrame.Companion.readResultSet(this, dbType, limit, inferNullability)
365+
259366
/**
260367
* Reads the data from a [ResultSet][java.sql.ResultSet] and converts it into a DataFrame.
261368
*
@@ -288,6 +395,31 @@ public fun DataFrame.Companion.readResultSet(
288395
return readResultSet(resultSet, dbType, limit, inferNullability)
289396
}
290397

398+
/**
399+
* Reads the data from a [ResultSet][java.sql.ResultSet] and converts it into a DataFrame.
400+
*
401+
* A [ResultSet][java.sql.ResultSet] object maintains a cursor pointing to its current row of data.
402+
* By default, a ResultSet object is not updatable and has a cursor that can only move forward.
403+
* Therefore, you can iterate through it only once, from the first row to the last row.
404+
*
405+
* For more details, refer to the official Java documentation on [ResultSet][java.sql.ResultSet].
406+
*
407+
* NOTE: Reading from the [ResultSet][java.sql.ResultSet] could potentially change its state.
408+
*
409+
* @param [connection] the connection to the database (it's required to extract the database type)
410+
* that the [ResultSet] belongs to.
411+
* @param [limit] the maximum number of rows to read from the [ResultSet][java.sql.ResultSet].
412+
* @param [inferNullability] indicates how the column nullability should be inferred.
413+
* @return the DataFrame generated from the [ResultSet][java.sql.ResultSet] data.
414+
*
415+
* [java.sql.ResultSet]: https://docs.oracle.com/javase/8/docs/api/java/sql/ResultSet.html
416+
*/
417+
public fun ResultSet.readDataFrame(
418+
connection: Connection,
419+
limit: Int = DEFAULT_LIMIT,
420+
inferNullability: Boolean = true,
421+
): AnyFrame = DataFrame.Companion.readResultSet(this, connection, limit, inferNullability)
422+
291423
/**
292424
* Reads all non-system tables from a database and returns them
293425
* as a map of SQL tables and corresponding dataframes using the provided database configuration and limit.
@@ -299,7 +431,7 @@ public fun DataFrame.Companion.readResultSet(
299431
* @return a map of [String] to [AnyFrame] objects representing the non-system tables from the database.
300432
*/
301433
public fun DataFrame.Companion.readAllSqlTables(
302-
dbConfig: DatabaseConfiguration,
434+
dbConfig: DbConnectionConfig,
303435
catalogue: String? = null,
304436
limit: Int = DEFAULT_LIMIT,
305437
inferNullability: Boolean = true,
@@ -366,10 +498,7 @@ public fun DataFrame.Companion.readAllSqlTables(
366498
* @param [tableName] the name of the SQL table for which to retrieve the schema.
367499
* @return the [DataFrameSchema] object representing the schema of the SQL table
368500
*/
369-
public fun DataFrame.Companion.getSchemaForSqlTable(
370-
dbConfig: DatabaseConfiguration,
371-
tableName: String,
372-
): DataFrameSchema {
501+
public fun DataFrame.Companion.getSchemaForSqlTable(dbConfig: DbConnectionConfig, tableName: String): DataFrameSchema {
373502
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
374503
return getSchemaForSqlTable(connection, tableName)
375504
}
@@ -405,10 +534,7 @@ public fun DataFrame.Companion.getSchemaForSqlTable(connection: Connection, tabl
405534
* @param [sqlQuery] the SQL query to execute and retrieve the schema from.
406535
* @return the schema of the SQL query as a [DataFrameSchema] object.
407536
*/
408-
public fun DataFrame.Companion.getSchemaForSqlQuery(
409-
dbConfig: DatabaseConfiguration,
410-
sqlQuery: String,
411-
): DataFrameSchema {
537+
public fun DataFrame.Companion.getSchemaForSqlQuery(dbConfig: DbConnectionConfig, sqlQuery: String): DataFrameSchema {
412538
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
413539
return getSchemaForSqlQuery(connection, sqlQuery)
414540
}
@@ -434,6 +560,40 @@ public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQ
434560
}
435561
}
436562

563+
/**
564+
* Retrieves the schema of an SQL query result or the SQL table using the provided database configuration.
565+
*
566+
* @param [sqlQueryOrTableName] the SQL query to execute, or the name of the SQL table, to retrieve the schema from.
567+
* @return the schema of the SQL query result or of the SQL table as a [DataFrameSchema] object.
568+
*/
569+
public fun DbConnectionConfig.getDataFrameSchema(sqlQueryOrTableName: String): DataFrameSchema =
570+
when {
571+
isSqlQuery(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName)
572+
573+
isSqlTableName(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName)
574+
575+
else -> throw IllegalArgumentException(
576+
"$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!",
577+
)
578+
}
579+
580+
/**
581+
* Retrieves the schema of an SQL query result or an SQL table using this database connection.
582+
*
583+
* @param [sqlQueryOrTableName] the SQL query to execute, or the name of the SQL table, to retrieve the schema from.
584+
* @return the schema of the SQL query result or of the SQL table as a [DataFrameSchema] object.
585+
*/
586+
public fun Connection.getDataFrameSchema(sqlQueryOrTableName: String): DataFrameSchema =
587+
when {
588+
isSqlQuery(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName)
589+
590+
isSqlTableName(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName)
591+
592+
else -> throw IllegalArgumentException(
593+
"$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!",
594+
)
595+
}
596+
437597
/**
438598
* Retrieves the schema from [ResultSet].
439599
*
@@ -448,6 +608,16 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbTyp
448608
return buildSchemaByTableColumns(tableColumns, dbType)
449609
}
450610

611+
/**
612+
* Retrieves the schema from [ResultSet].
613+
*
614+
* NOTE: This function neither closes the connection or the result set nor retrieves data from the result set.
615+
*
616+
* @param [dbType] the type of database that the [ResultSet] belongs to.
617+
* @return the schema of the [ResultSet] as a [DataFrameSchema] object.
618+
*/
619+
// Extension shortcut: forwards this ResultSet and the given dbType to DataFrame.getSchemaForResultSet.
public fun ResultSet.getDataFrameSchema(dbType: DbType): DataFrameSchema {
    return DataFrame.getSchemaForResultSet(this, dbType)
}
620+
451621
/**
452622
* Retrieves the schema from [ResultSet].
453623
*
@@ -465,13 +635,24 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, conne
465635
return buildSchemaByTableColumns(tableColumns, dbType)
466636
}
467637

638+
/**
639+
* Retrieves the schema from [ResultSet].
640+
*
641+
* NOTE: This function neither closes the connection or the result set nor retrieves data from the result set.
642+
*
643+
* @param [connection] the connection to the database (it's required to extract the database type).
644+
* @return the schema of the [ResultSet] as a [DataFrameSchema] object.
645+
*/
646+
public fun ResultSet.getDataFrameSchema(connection: Connection): DataFrameSchema =
647+
DataFrame.getSchemaForResultSet(this, connection)
648+
468649
/**
469650
* Retrieves the schemas of all non-system tables in the database using the provided database configuration.
470651
*
471652
* @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
472653
* @return a map of [String, DataFrameSchema] objects representing the table name and its schema for each non-system table.
473654
*/
474-
public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): Map<String, DataFrameSchema> {
655+
public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DbConnectionConfig): Map<String, DataFrameSchema> {
475656
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
476657
return getSchemaForAllSqlTables(connection)
477658
}

0 commit comments

Comments
 (0)