Skip to content

Commit c40fb04

Browse files
authored
Add support for H2 modes (#720)
* Refactor H2 database util to inherit DbType. The H2 database util has been refactored to inherit DbType, allowing for the use of different dialects. Updated the relevant test cases and added a function to correctly identify the dialect based on the URL. Added error handling for unsupported dialects.
* Split the tests for databases into two parts: local and H2-oriented.
* Fixed tests.
* Add jts-core library and improve code documentation. The jts-core library was added to the project dependencies, allowing for usage in the codebase. Moreover, some improvements in the documentation of the code were made. Specifically, better explanations were provided for error cases in the `extractDBTypeFromConnection` and `extractDBTypeFromUrl` functions, and extensive documentation was added for the companion object in the H2.kt file.
* Refactor code formatting across several DF-JDBC files. Various minor formatting changes have been applied to improve code readability. This includes rearranging import statements in the 'postgresTest.kt' file, removing superfluous empty lines in 'mssqlTest.kt', and adjusting whitespace for improved consistency in 'util.kt' and 'H2.kt'.
* Update import statements across multiple files.
* Refactor H2 class and extend tests. Performed a refinement in the H2 class and implemented test coverage for specific conditions. The H2 class now uses the class reference for comparison instead of the simple name, eliminating string comparison. Additionally, a test has been added to check that an exception is properly thrown when specifying an H2 database with the H2 dialect. Minor import adjustments were also made in the readJdbc and mssqlTest files.
1 parent cf234e7 commit c40fb04

File tree

18 files changed

+1689
-58
lines changed

18 files changed

+1689
-58
lines changed

dataframe-jdbc/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ dependencies {
2828
testImplementation(libs.mssql)
2929
testImplementation(libs.junit)
3030
testImplementation(libs.sl4j)
31+
testImplementation(libs.jts)
3132
testImplementation(libs.kotestAssertions) {
3233
exclude("org.jetbrains.kotlin", "kotlin-stdlib-jdk8")
3334
}

dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,33 +10,67 @@ import kotlin.reflect.KType
1010
/**
1111
* Represents the H2 database type.
1212
*
13-
* This class provides methods to convert data from a ResultSet to the appropriate type for H2,
13+
* This class provides methods to convert data from a ResultSet to the appropriate type for H2
1414
* and to generate the corresponding column schema.
1515
*
16-
* NOTE: All date and timestamp related types are converted to String to avoid java.sql.* types.
16+
* NOTE: All date and timestamp-related types are converted to String to avoid java.sql.* types.
1717
*/
18-
public object H2 : DbType("h2") {
18+
public class H2(public val dialect: DbType = MySql) : DbType("h2") {
19+
init {
20+
require(dialect::class != H2::class) { "H2 database could not be specified with H2 dialect!" }
21+
}
22+
23+
/**
24+
* Contains constants with the mode names (`MySQL`, `PostgreSQL`, `MSSQLServer`, `MariaDB`) of the H2 database.
25+
*
26+
* The mode value is used in the [extractDBTypeFromConnection] function to determine the corresponding `DbType` for the H2 database connection URL.
27+
* For example, if the URL contains the mode value "MySQL", the H2 instance with the MySQL database type is returned.
28+
* If the connection URL is not an H2 URL, the `DbType` is determined from the URL alone, without any mode value.
29+
*
30+
* @see [extractDBTypeFromConnection]
31+
* @see [createH2Instance]
32+
*/
33+
public companion object {
34+
/** It represents the mode value "MySQL" for the H2 database. */
35+
public const val MODE_MYSQL: String = "MySQL"
36+
37+
/** It represents the mode value "PostgreSQL" for the H2 database. */
38+
public const val MODE_POSTGRESQL: String = "PostgreSQL"
39+
40+
/** It represents the mode value "MSSQLServer" for the H2 database. */
41+
public const val MODE_MSSQLSERVER: String = "MSSQLServer"
42+
43+
/** It represents the mode value "MariaDB" for the H2 database. */
44+
public const val MODE_MARIADB: String = "MariaDB"
45+
}
46+
1947
override val driverClassName: String
2048
get() = "org.h2.Driver"
2149

2250
override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? {
23-
return null
51+
return dialect.convertSqlTypeToColumnSchemaValue(tableColumnMetadata)
2452
}
2553

2654
override fun isSystemTable(tableMetadata: TableMetadata): Boolean {
27-
return tableMetadata.name.lowercase(Locale.getDefault()).contains("sys_") ||
28-
tableMetadata.schemaName?.lowercase(Locale.getDefault())?.contains("information_schema") ?: false
55+
val locale = Locale.getDefault()
56+
fun String?.containsWithLowercase(substr: String) = this?.lowercase(locale)?.contains(substr) == true
57+
val schemaName = tableMetadata.schemaName
58+
59+
// could be extended with other indicators of H2 system tables
60+
val isH2SystemTable = schemaName.containsWithLowercase("information_schema")
61+
62+
return isH2SystemTable || dialect.isSystemTable(tableMetadata)
2963
}
3064

3165
override fun buildTableMetadata(tables: ResultSet): TableMetadata {
32-
return TableMetadata(
33-
tables.getString("TABLE_NAME"),
34-
tables.getString("TABLE_SCHEM"),
35-
tables.getString("TABLE_CAT")
36-
)
66+
return dialect.buildTableMetadata(tables)
3767
}
3868

3969
override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? {
40-
return null
70+
return dialect.convertSqlTypeToKType(tableColumnMetadata)
71+
}
72+
73+
public override fun sqlQueryLimit(sqlQuery: String, limit: Int): String {
74+
return dialect.sqlQueryLimit(sqlQuery, limit)
4175
}
4276
}

dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
44
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
55
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
66
import java.sql.ResultSet
7-
import java.util.*
7+
import java.util.Locale
88
import kotlin.reflect.KType
9-
import kotlin.reflect.full.createType
109

1110
/**
1211
* Represents the MSSQL database type.

dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,75 @@
11
package org.jetbrains.kotlinx.dataframe.io.db
22

3+
import io.github.oshai.kotlinlogging.KotlinLogging
4+
import java.sql.Connection
35
import java.sql.SQLException
6+
import java.util.Locale
7+
8+
private val logger = KotlinLogging.logger {}
9+
10+
/**
11+
* Extracts the database type from the given connection.
12+
*
13+
* @param [connection] the database connection.
14+
* @return the corresponding [DbType].
15+
* @throws [IllegalStateException] if URL information is missing in the connection meta-data or the H2 mode cannot be found in the H2 settings.
16+
* @throws [IllegalArgumentException] if the URL specifies an unsupported database type.
17+
* @throws [SQLException] if a database access error occurs while reading the connection meta-data or querying the H2 mode.
18+
*/
19+
public fun extractDBTypeFromConnection(connection: Connection): DbType {
20+
val url = connection.metaData?.url ?: throw IllegalStateException("URL information is missing in connection meta data!")
21+
logger.info { "Processing DB type extraction for connection url: $url" }
22+
23+
return if (url.contains(H2().dbTypeInJdbcUrl)) {
24+
// works only for H2 version 2
25+
val modeQuery = "SELECT SETTING_VALUE FROM INFORMATION_SCHEMA.SETTINGS WHERE SETTING_NAME = 'MODE'"
26+
var mode = ""
27+
connection.createStatement().use { st ->
28+
st.executeQuery(
29+
modeQuery
30+
).use { rs ->
31+
if (rs.next()) {
32+
mode = rs.getString("SETTING_VALUE")
33+
logger.debug { "Fetched H2 DB mode: $mode" }
34+
} else {
35+
throw IllegalStateException("The information about H2 mode is not found in the H2 meta-data!")
36+
}
37+
}
38+
}
39+
40+
// Only the MySQL, MSSQLServer, PostgreSQL and MariaDB H2 modes are mapped to a dialect; any other mode is rejected
41+
when (mode.lowercase(Locale.getDefault())) {
42+
H2.MODE_MYSQL.lowercase(Locale.getDefault()) -> H2(MySql)
43+
H2.MODE_MSSQLSERVER.lowercase(Locale.getDefault()) -> H2(MsSql)
44+
H2.MODE_POSTGRESQL.lowercase(Locale.getDefault()) -> H2(PostgreSql)
45+
H2.MODE_MARIADB.lowercase(Locale.getDefault()) -> H2(MariaDb)
46+
else -> {
47+
val message = "Unsupported database type in the url: $url. " +
48+
"Only MySQL, MariaDB, MSSQL and PostgreSQL are supported!"
49+
logger.error { message }
50+
51+
throw IllegalArgumentException(message)
52+
}
53+
}
54+
} else {
55+
val dbType = extractDBTypeFromUrl(url)
56+
logger.info { "Identified DB type as $dbType from url: $url" }
57+
dbType
58+
}
59+
}
460

561
/**
662
* Extracts the database type from the given JDBC URL.
763
*
864
* @param [url] the JDBC URL.
965
* @return the corresponding [DbType].
10-
* @throws RuntimeException if the url is null.
66+
* @throws [RuntimeException] if the url is null.
1167
*/
1268
public fun extractDBTypeFromUrl(url: String?): DbType {
1369
if (url != null) {
70+
val helperH2Instance = H2()
1471
return when {
15-
H2.dbTypeInJdbcUrl in url -> H2
72+
helperH2Instance.dbTypeInJdbcUrl in url -> createH2Instance(url)
1673
MariaDb.dbTypeInJdbcUrl in url -> MariaDb
1774
MySql.dbTypeInJdbcUrl in url -> MySql
1875
Sqlite.dbTypeInJdbcUrl in url -> Sqlite
@@ -28,6 +85,37 @@ public fun extractDBTypeFromUrl(url: String?): DbType {
2885
}
2986
}
3087

88+
/**
89+
* Creates an [H2] instance whose dialect is chosen from the `MODE` parameter of the provided JDBC URL.
90+
*
91+
* @param [url] The JDBC URL representing the database connection.
92+
* @return The corresponding [DbType] instance.
93+
* @throws [IllegalArgumentException] if the URL contains no `MODE=<value>;` parameter (the terminating semicolon is required) or the mode is not supported.
94+
*/
95+
private fun createH2Instance(url: String): DbType {
96+
val modePattern = "MODE=(.*?);".toRegex()
97+
val matchResult = modePattern.find(url)
98+
99+
val mode: String = if (matchResult != null && matchResult.groupValues.size == 2) {
100+
matchResult.groupValues[1]
101+
} else {
102+
throw IllegalArgumentException("The provided URL `$url` does not contain a valid mode.")
103+
}
104+
105+
// Only the MySQL, MSSQLServer, PostgreSQL and MariaDB H2 modes are mapped to a dialect; any other mode is rejected
106+
return when (mode.lowercase(Locale.getDefault())) {
107+
H2.MODE_MYSQL.lowercase(Locale.getDefault()) -> H2(MySql)
108+
H2.MODE_MSSQLSERVER.lowercase(Locale.getDefault()) -> H2(MsSql)
109+
H2.MODE_POSTGRESQL.lowercase(Locale.getDefault()) -> H2(PostgreSql)
110+
H2.MODE_MARIADB.lowercase(Locale.getDefault()) -> H2(MariaDb)
111+
112+
else -> throw IllegalArgumentException(
113+
"Unsupported database mode: $mode. " +
114+
"Only MySQL, MariaDB, MSSQL, PostgreSQL modes are supported!"
115+
)
116+
}
117+
}
118+
31119
/**
32120
* Retrieves the driver class name from the given JDBC URL.
33121
*

dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
77
import org.jetbrains.kotlinx.dataframe.api.Infer
88
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
99
import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl
10-
import org.jetbrains.kotlinx.dataframe.io.db.DbType
11-
import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromUrl
10+
import org.jetbrains.kotlinx.dataframe.io.db.*
1211
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
1312
import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema
1413
import java.math.BigDecimal
@@ -138,7 +137,7 @@ public fun DataFrame.Companion.readSqlTable(
138137
inferNullability: Boolean = true,
139138
): AnyFrame {
140139
val url = connection.metaData.url
141-
val dbType = extractDBTypeFromUrl(url)
140+
val dbType = extractDBTypeFromConnection(connection)
142141

143142
val selectAllQuery = if (limit > 0) dbType.sqlQueryLimit("SELECT * FROM $tableName", limit)
144143
else "SELECT * FROM $tableName"
@@ -203,8 +202,7 @@ public fun DataFrame.Companion.readSqlQuery(
203202
"Also it should not contain any separators like `;`."
204203
}
205204

206-
val url = connection.metaData.url
207-
val dbType = extractDBTypeFromUrl(url)
205+
val dbType = extractDBTypeFromConnection(connection)
208206

209207
val internalSqlQuery = if (limit > 0) dbType.sqlQueryLimit(sqlQuery, limit) else sqlQuery
210208

@@ -283,8 +281,7 @@ public fun DataFrame.Companion.readResultSet(
283281
limit: Int = DEFAULT_LIMIT,
284282
inferNullability: Boolean = true,
285283
): AnyFrame {
286-
val url = connection.metaData.url
287-
val dbType = extractDBTypeFromUrl(url)
284+
val dbType = extractDBTypeFromConnection(connection)
288285

289286
return readResultSet(resultSet, dbType, limit, inferNullability)
290287
}
@@ -329,8 +326,7 @@ public fun DataFrame.Companion.readAllSqlTables(
329326
inferNullability: Boolean = true,
330327
): Map<String, AnyFrame> {
331328
val metaData = connection.metaData
332-
val url = connection.metaData.url
333-
val dbType = extractDBTypeFromUrl(url)
329+
val dbType = extractDBTypeFromConnection(connection)
334330

335331
// exclude system tables and other tables without data, although this filtering appears to be poorly supported by many databases
336332
val tables = metaData.getTables(catalogue, null, null, arrayOf("TABLE"))
@@ -390,8 +386,7 @@ public fun DataFrame.Companion.getSchemaForSqlTable(
390386
connection: Connection,
391387
tableName: String
392388
): DataFrameSchema {
393-
val url = connection.metaData.url
394-
val dbType = extractDBTypeFromUrl(url)
389+
val dbType = extractDBTypeFromConnection(connection)
395390

396391
val sqlQuery = "SELECT * FROM $tableName"
397392
val selectFirstRowQuery = dbType.sqlQueryLimit(sqlQuery, limit = 1)
@@ -432,8 +427,7 @@ public fun DataFrame.Companion.getSchemaForSqlQuery(
432427
* @see DriverManager.getConnection
433428
*/
434429
public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQuery: String): DataFrameSchema {
435-
val url = connection.metaData.url
436-
val dbType = extractDBTypeFromUrl(url)
430+
val dbType = extractDBTypeFromConnection(connection)
437431

438432
connection.createStatement().use { st ->
439433
st.executeQuery(sqlQuery).use { rs ->
@@ -468,8 +462,7 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbTyp
468462
* @return the schema of the [ResultSet] as a [DataFrameSchema] object.
469463
*/
470464
public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, connection: Connection): DataFrameSchema {
471-
val url = connection.metaData.url
472-
val dbType = extractDBTypeFromUrl(url)
465+
val dbType = extractDBTypeFromConnection(connection)
473466

474467
val tableColumns = getTableColumnsMetadata(resultSet)
475468
return buildSchemaByTableColumns(tableColumns, dbType)
@@ -495,8 +488,7 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfig
495488
*/
496489
public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): Map<String, DataFrameSchema> {
497490
val metaData = connection.metaData
498-
val url = connection.metaData.url
499-
val dbType = extractDBTypeFromUrl(url)
491+
val dbType = extractDBTypeFromConnection(connection)
500492

501493
val tableTypes = arrayOf("TABLE")
502494
// exclude system tables and other tables without data

0 commit comments

Comments
 (0)