diff --git a/dataframe-jdbc/build.gradle.kts b/dataframe-jdbc/build.gradle.kts index 19446166c1..6e7e8dd42c 100644 --- a/dataframe-jdbc/build.gradle.kts +++ b/dataframe-jdbc/build.gradle.kts @@ -25,6 +25,7 @@ dependencies { testImplementation(libs.postgresql) testImplementation(libs.mysql) testImplementation(libs.h2db) + testImplementation(libs.mssql) testImplementation(libs.junit) testImplementation(libs.sl4j) testImplementation(libs.kotestAssertions) { diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt index d025a34b82..aae6eb995e 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt @@ -50,4 +50,14 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) { * @return The corresponding Kotlin data type, or null if no mapping is found. */ public abstract fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? + + /** + * Constructs a SQL query with a limit clause. + * + * @param sqlQuery The original SQL query. + * @param limit The maximum number of rows to retrieve from the query. Default is 1. + * @return A new SQL query with the limit clause added. 
+ */ + public open fun sqlQueryLimit(sqlQuery: String, limit: Int = 1): String = + "$sqlQuery LIMIT $limit" } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt new file mode 100644 index 0000000000..05aed59a78 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt @@ -0,0 +1,62 @@ +package org.jetbrains.kotlinx.dataframe.io.db + +import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata +import org.jetbrains.kotlinx.dataframe.io.TableMetadata +import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema +import java.sql.ResultSet +import java.util.* +import kotlin.reflect.KType +import kotlin.reflect.full.createType + +/** + * Represents the MSSQL database type. + * + * This class provides methods to convert data from a ResultSet to the appropriate type for MSSQL, + * and to generate the corresponding column schema. + */ +public object MsSql : DbType("sqlserver") { + override val driverClassName: String + get() = "com.microsoft.sqlserver.jdbc.SQLServerDriver" + + override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? 
{ + return null + } + + override fun isSystemTable(tableMetadata: TableMetadata): Boolean { + val locale = Locale.getDefault() + + fun String?.containsWithLowercase(substr: String) = this?.lowercase(locale)?.contains(substr) == true + + val schemaName = tableMetadata.schemaName + val tableName = tableMetadata.name + val catalogName = tableMetadata.catalogue + + return schemaName.containsWithLowercase("sys") || + schemaName.containsWithLowercase("information_schema") || + tableName.startsWith("sys") || + tableName.startsWith("dt") || + tableName.containsWithLowercase("sys_config") || + catalogName.containsWithLowercase("system") || + catalogName.containsWithLowercase("master") || + catalogName.containsWithLowercase("model") || + catalogName.containsWithLowercase("msdb") || + catalogName.containsWithLowercase("tempdb") + } + + override fun buildTableMetadata(tables: ResultSet): TableMetadata { + return TableMetadata( + tables.getString("table_name"), + tables.getString("table_schem"), + tables.getString("table_cat") + ) + } + + override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? 
{ + return null + } + + public override fun sqlQueryLimit(sqlQuery: String, limit: Int): String { + /* T-SQL has no LIMIT clause, so the row cap is injected as TOP right after the first SELECT keyword; replaceFirst returns a new string, which must be returned (the input string is never modified). */ + return sqlQuery.replaceFirst("SELECT", "SELECT TOP $limit", ignoreCase = true) + } +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt index 1ea06bc1e4..793b41a93e 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt @@ -17,9 +17,10 @@ public fun extractDBTypeFromUrl(url: String?): DbType { MySql.dbTypeInJdbcUrl in url -> MySql Sqlite.dbTypeInJdbcUrl in url -> Sqlite PostgreSql.dbTypeInJdbcUrl in url -> PostgreSql + MsSql.dbTypeInJdbcUrl in url -> MsSql else -> throw IllegalArgumentException( "Unsupported database type in the url: $url. " + - "Only H2, MariaDB, MySQL, SQLite and PostgreSQL are supported!" + "Only H2, MariaDB, MySQL, MSSQL, SQLite and PostgreSQL are supported!" ) } } else { diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index 2b6d0e1b63..527808fd7b 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -137,17 +137,17 @@ public fun DataFrame.Companion.readSqlTable( limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, ): AnyFrame { - var preparedQuery = "SELECT * FROM $tableName" - if (limit > 0) preparedQuery += " LIMIT $limit" - val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) + val selectAllQuery = if (limit > 0) dbType.sqlQueryLimit("SELECT * FROM $tableName", limit) + else "SELECT * FROM $tableName" + connection.createStatement().use { st -> logger.debug { "Connection with url:$url is established successfully." 
} st.executeQuery( - preparedQuery + selectAllQuery ).use { rs -> val tableColumns = getTableColumnsMetadata(rs) return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability) @@ -206,8 +206,7 @@ public fun DataFrame.Companion.readSqlQuery( val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) - var internalSqlQuery = sqlQuery - if (limit > 0) internalSqlQuery += " LIMIT $limit" + val internalSqlQuery = if (limit > 0) dbType.sqlQueryLimit(sqlQuery, limit) else sqlQuery logger.debug { "Executing SQL query: $internalSqlQuery" } @@ -317,9 +316,11 @@ public fun DataFrame.Companion.readAllSqlTables( val table = dbType.buildTableMetadata(tables) if (!dbType.isSystemTable(table)) { // we filter her second time because of specific logic with SQLite and possible issues with future databases - // val tableName = if (table.catalogue != null) table.catalogue + "." + table.name else table.name - val tableName = if (catalogue != null) catalogue + "." + table.name else table.name - + val tableName = when { + catalogue != null && table.schemaName != null -> "$catalogue.${table.schemaName}.${table.name}" + catalogue != null && table.schemaName == null -> "$catalogue.${table.name}" + else -> table.name + } // TODO: both cases is schema specified or not in URL // in h2 database name is recognized as a schema name https://www.h2database.com/html/features.html#database_url // https://stackoverflow.com/questions/20896935/spring-hibernate-h2-database-schema-not-found @@ -367,11 +368,12 @@ public fun DataFrame.Companion.getSchemaForSqlTable( val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) - val preparedQuery = "SELECT * FROM $tableName LIMIT 1" + val sqlQuery = "SELECT * FROM $tableName" + val selectFirstRowQuery = dbType.sqlQueryLimit(sqlQuery, limit = 1) connection.createStatement().use { st -> st.executeQuery( - preparedQuery + selectFirstRowQuery ).use { rs -> val tableColumns = getTableColumnsMetadata(rs) return 
buildSchemaByTableColumns(tableColumns, dbType) @@ -532,15 +534,19 @@ private fun getTableColumnsMetadata(rs: ResultSet): MutableList() val dbConfig = DatabaseConfiguration(url = URL) @@ -675,6 +676,8 @@ class JdbcTest { saleDataSchema1.columns["amount"]!!.type shouldBe typeOf() } + // TODO: add the same test for each particular database and refactor the scenario to the common test case + // https://github.com/Kotlin/dataframe/issues/688 @Test fun `infer nullability`() { // prepare tables and data diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt new file mode 100644 index 0000000000..7132b29290 --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt @@ -0,0 +1,401 @@ +package org.jetbrains.kotlinx.dataframe.io + +import io.kotest.matchers.shouldBe +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.JdbcTest.Companion +import org.jetbrains.kotlinx.dataframe.io.db.H2 +import org.junit.AfterClass +import org.junit.BeforeClass +import org.junit.Ignore +import org.junit.Test +import java.math.BigDecimal +import java.sql.Connection +import java.sql.DriverManager +import java.sql.ResultSet +import java.sql.SQLException +import java.util.* +import kotlin.reflect.typeOf + +private const val URL = "jdbc:sqlserver://localhost:1433;encrypt=true;trustServerCertificate=true" +private const val USER_NAME = "root" +private const val PASSWORD = "pass" +private const val TEST_DATABASE_NAME = "testKDFdatabase" + +@DataSchema +interface Table1MSSSQL { + val id: Int + val bigintColumn: Long + val binaryColumn: ByteArray + val bitColumn: Boolean + val charColumn: Char + val dateColumn: Date + val datetime3Column: 
java.sql.Timestamp + val datetime2Column: java.sql.Timestamp + val datetimeoffset2Column: String + val decimalColumn: BigDecimal + val floatColumn: Double + val imageColumn: ByteArray? + val intColumn: Int + val moneyColumn: BigDecimal + val ncharColumn: Char + val ntextColumn: String + val numericColumn: BigDecimal + val nvarcharColumn: String + val nvarcharMaxColumn: String + val realColumn: Float + val smalldatetimeColumn: java.sql.Timestamp + val smallintColumn: Int + val smallmoneyColumn: BigDecimal + val timeColumn: java.sql.Time + val timestampColumn: java.sql.Timestamp + val tinyintColumn: Int + val uniqueidentifierColumn: Char + val varbinaryColumn: ByteArray + val varbinaryMaxColumn: ByteArray + val varcharColumn: String + val varcharMaxColumn: String + val xmlColumn: String + val sqlvariantColumn: String + val geometryColumn: String + val geographyColumn: String +} + +@Ignore +class MSSQLTest { + companion object { + private lateinit var connection: Connection + + @BeforeClass + @JvmStatic + fun setUpClass() { + connection = DriverManager.getConnection(URL, USER_NAME, PASSWORD) + + connection.createStatement().use { st -> + // Drop the test database if it exists + val dropDatabaseQuery = "IF DB_ID('$TEST_DATABASE_NAME') IS NOT NULL\n" + + "DROP DATABASE $TEST_DATABASE_NAME" + st.executeUpdate(dropDatabaseQuery) + + // Create the test database + val createDatabaseQuery = "CREATE DATABASE $TEST_DATABASE_NAME" + st.executeUpdate(createDatabaseQuery) + + // Use the newly created database + val useDatabaseQuery = "USE $TEST_DATABASE_NAME" + st.executeUpdate(useDatabaseQuery) + } + + @Language("SQL") + val createTableQuery = """ + CREATE TABLE Table1 ( + id INT NOT NULL IDENTITY PRIMARY KEY, + bigintColumn BIGINT, + binaryColumn BINARY(50), + bitColumn BIT, + charColumn CHAR(10), + dateColumn DATE, + datetime3Column DATETIME2(3), + datetime2Column DATETIME2, + datetimeoffset2Column DATETIMEOFFSET(2), + decimalColumn DECIMAL(10,2), + floatColumn FLOAT, + 
imageColumn IMAGE, + intColumn INT, + moneyColumn MONEY, + ncharColumn NCHAR(10), + ntextColumn NTEXT, + numericColumn NUMERIC(10,2), + nvarcharColumn NVARCHAR(50), + nvarcharMaxColumn NVARCHAR(MAX), + realColumn REAL, + smalldatetimeColumn SMALLDATETIME, + smallintColumn SMALLINT, + smallmoneyColumn SMALLMONEY, + textColumn TEXT, + timeColumn TIME, + timestampColumn DATETIME2, + tinyintColumn TINYINT, + uniqueidentifierColumn UNIQUEIDENTIFIER, + varbinaryColumn VARBINARY(50), + varbinaryMaxColumn VARBINARY(MAX), + varcharColumn VARCHAR(50), + varcharMaxColumn VARCHAR(MAX), + xmlColumn XML, + sqlvariantColumn SQL_VARIANT, + geometryColumn GEOMETRY, + geographyColumn GEOGRAPHY + ); + """ + + connection.createStatement().execute( + createTableQuery.trimIndent() + ) + + @Language("SQL") + val insertData1 = """ + INSERT INTO Table1 ( + bigintColumn, binaryColumn, bitColumn, charColumn, dateColumn, datetime3Column, datetime2Column, + datetimeoffset2Column, decimalColumn, floatColumn, imageColumn, intColumn, moneyColumn, ncharColumn, + ntextColumn, numericColumn, nvarcharColumn, nvarcharMaxColumn, realColumn, smalldatetimeColumn, + smallintColumn, smallmoneyColumn, textColumn, timeColumn, timestampColumn, tinyintColumn, + uniqueidentifierColumn, varbinaryColumn, varbinaryMaxColumn, varcharColumn, varcharMaxColumn, + xmlColumn, sqlvariantColumn, geometryColumn, geographyColumn + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """.trimIndent() + + connection.prepareStatement(insertData1).use { st -> + for (i in 1..5) { + st.setLong(1, 123456789012345L) // bigintColumn + st.setBytes(2, byteArrayOf(0x01, 0x23, 0x45, 0x67, 0x67, 0x67, 0x67, 0x67)) // binaryColumn + st.setBoolean(3, true) // bitColumn + st.setString(4, "Sample") // charColumn + st.setDate(5, java.sql.Date(System.currentTimeMillis())) // dateColumn + st.setTimestamp(6, java.sql.Timestamp(System.currentTimeMillis())) // datetime3Column + st.setTimestamp(7, java.sql.Timestamp(System.currentTimeMillis())) // datetime2Column + st.setTimestamp(8, java.sql.Timestamp(System.currentTimeMillis())) // datetimeoffset2Column + st.setBigDecimal(9, BigDecimal("12345.67")) // decimalColumn + st.setFloat(10, 123.45f) // floatColumn + st.setNull(11, java.sql.Types.NULL) // imageColumn (assuming nullable) + st.setInt(12, 123456) // intColumn + st.setBigDecimal(13, BigDecimal("123.45")) // moneyColumn + st.setString(14, "Sample") // ncharColumn + st.setString(15, "Sample$i text") // ntextColumn + st.setBigDecimal(16, BigDecimal("1234.56")) // numericColumn + st.setString(17, "Sample") // nvarcharColumn + st.setString(18, "Sample$i text") // nvarcharMaxColumn + st.setFloat(19, 123.45f) // realColumn + st.setTimestamp(20, java.sql.Timestamp(System.currentTimeMillis())) // smalldatetimeColumn + st.setInt(21, 123) // smallintColumn + st.setBigDecimal(22, BigDecimal("123.45")) // smallmoneyColumn + st.setString(23, "Sample$i text") // textColumn + st.setTime(24, java.sql.Time(System.currentTimeMillis())) // timeColumn + st.setTimestamp(25, java.sql.Timestamp(System.currentTimeMillis())) // timestampColumn + st.setInt(26, 123) // tinyintColumn + //st.setObject(27, null) // udtColumn (assuming nullable) + st.setObject(27, UUID.randomUUID()) // uniqueidentifierColumn + st.setBytes(28, byteArrayOf(0x01, 0x23, 0x45, 0x67, 0x67, 0x67, 0x67, 0x67)) // varbinaryColumn + st.setBytes(29, byteArrayOf(0x01, 0x23, 0x45, 0x67, 0x67, 0x67, 0x67, 0x67)) // 
varbinaryMaxColumn + st.setString(30, "Sample$i") // varcharColumn + st.setString(31, "Sample$i text") // varcharMaxColumn + st.setString(32, "Sample$i") // xmlColumn + st.setString(33, "SQL_VARIANT") // sqlvariantColumn + st.setBytes( + 34, + byteArrayOf( + 0xE6.toByte(), 0x10, 0x00, 0x00, 0x01, 0x0C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x44, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x05, 0x4C, 0x0 + ) + ) // geometryColumn + st.setString(35, "POINT(1 1)") // geographyColumn + st.executeUpdate() + } + } + } + + @AfterClass + @JvmStatic + fun tearDownClass() { + try { + connection.createStatement().use { st -> st.execute("DROP DATABASE IF EXISTS $TEST_DATABASE_NAME") } + connection.close() + } catch (e: SQLException) { + e.printStackTrace() + } + } + } + + @Test + fun `basic test for reading sql tables`() { + val df1 = DataFrame.readSqlTable(connection, "table1", limit = 5).cast() + + val result = df1.filter { it[Table1MSSSQL::id] == 1 } + result[0][30] shouldBe "Sample1" + result[0][Table1MSSSQL::bigintColumn] shouldBe 123456789012345L + result[0][Table1MSSSQL::bitColumn] shouldBe true + result[0][Table1MSSSQL::intColumn] shouldBe 123456 + result[0][Table1MSSSQL::ntextColumn] shouldBe "Sample1 text" + + val schema = DataFrame.getSchemaForSqlTable(connection, "table1") + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["bigintColumn"]!!.type shouldBe typeOf() + schema.columns["binaryColumn"]!!.type shouldBe typeOf() + schema.columns["bitColumn"]!!.type shouldBe typeOf() + schema.columns["charColumn"]!!.type shouldBe typeOf() + schema.columns["dateColumn"]!!.type shouldBe typeOf() + schema.columns["datetime3Column"]!!.type shouldBe typeOf() + schema.columns["datetime2Column"]!!.type shouldBe typeOf() + schema.columns["datetimeoffset2Column"]!!.type shouldBe typeOf() + schema.columns["decimalColumn"]!!.type shouldBe typeOf() + schema.columns["floatColumn"]!!.type shouldBe typeOf() + 
schema.columns["imageColumn"]!!.type shouldBe typeOf() + schema.columns["intColumn"]!!.type shouldBe typeOf() + schema.columns["moneyColumn"]!!.type shouldBe typeOf() + schema.columns["ncharColumn"]!!.type shouldBe typeOf() + schema.columns["ntextColumn"]!!.type shouldBe typeOf() + schema.columns["numericColumn"]!!.type shouldBe typeOf() + schema.columns["nvarcharColumn"]!!.type shouldBe typeOf() + schema.columns["nvarcharMaxColumn"]!!.type shouldBe typeOf() + schema.columns["realColumn"]!!.type shouldBe typeOf() + schema.columns["smalldatetimeColumn"]!!.type shouldBe typeOf() + schema.columns["smallintColumn"]!!.type shouldBe typeOf() + schema.columns["smallmoneyColumn"]!!.type shouldBe typeOf() + schema.columns["timeColumn"]!!.type shouldBe typeOf() + schema.columns["timestampColumn"]!!.type shouldBe typeOf() + schema.columns["tinyintColumn"]!!.type shouldBe typeOf() + schema.columns["uniqueidentifierColumn"]!!.type shouldBe typeOf() + schema.columns["varbinaryColumn"]!!.type shouldBe typeOf() + schema.columns["varbinaryMaxColumn"]!!.type shouldBe typeOf() + schema.columns["varcharColumn"]!!.type shouldBe typeOf() + schema.columns["varcharMaxColumn"]!!.type shouldBe typeOf() + schema.columns["xmlColumn"]!!.type shouldBe typeOf() + schema.columns["sqlvariantColumn"]!!.type shouldBe typeOf() + schema.columns["geometryColumn"]!!.type shouldBe typeOf() + schema.columns["geographyColumn"]!!.type shouldBe typeOf() + } + + @Test + fun `read from sql query`() { + @Language("SQL") + val sqlQuery = """ + SELECT + Table1.id, + Table1.bigintColumn + FROM Table1 + """.trimIndent() + + val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery, limit = 3).cast() + val result = df.filter { it[Table1MSSSQL::id] == 1 } + result[0][Table1MSSSQL::bigintColumn] shouldBe 123456789012345L + + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery = sqlQuery) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["bigintColumn"]!!.type shouldBe typeOf() + } + + 
@Test + fun `read from all tables`() { + val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 4) + + val table1Df = dataframes[0].cast() + + table1Df.rowsCount() shouldBe 4 + table1Df.filter { it[Table1MSSSQL::id] > 2 }.rowsCount() shouldBe 2 + table1Df[0][Table1MSSSQL::bigintColumn] shouldBe 123456789012345L + } + + // TODO: add the same test for each particular database and refactor the scenario to the common test case + // https://github.com/Kotlin/dataframe/issues/688 + @Test + fun `infer nullability`() { + // prepare tables and data + @Language("SQL") + val createTestTable1Query = """ + CREATE TABLE TestTable1 ( + id INT PRIMARY KEY, + name VARCHAR(50), + surname VARCHAR(50), + age INT NOT NULL + ) + """ + + connection.createStatement().execute(createTestTable1Query) + + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (1, 'John', 'Crawford', 40)") + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (2, 'Alice', 'Smith', 25)") + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (3, 'Bob', 'Johnson', 47)") + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (4, 'Sam', NULL, 15)") + + // start testing `readSqlTable` method + + // with default inferNullability: Boolean = true + val tableName = "TestTable1" + val df = DataFrame.readSqlTable(connection, tableName) + df.schema().columns["id"]!!.type shouldBe typeOf() + df.schema().columns["name"]!!.type shouldBe typeOf() + df.schema().columns["surname"]!!.type shouldBe typeOf() + df.schema().columns["age"]!!.type shouldBe typeOf() + + val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName) + dataSchema.columns.size shouldBe 4 + dataSchema.columns["id"]!!.type shouldBe typeOf() + dataSchema.columns["name"]!!.type shouldBe typeOf() + dataSchema.columns["surname"]!!.type shouldBe typeOf() + 
dataSchema.columns["age"]!!.type shouldBe typeOf() + + // with inferNullability: Boolean = false + val df1 = DataFrame.readSqlTable(connection, tableName, inferNullability = false) + df1.schema().columns["id"]!!.type shouldBe typeOf() + df1.schema().columns["name"]!!.type shouldBe typeOf() // <=== this column changed a type because it doesn't contain nulls + df1.schema().columns["surname"]!!.type shouldBe typeOf() + df1.schema().columns["age"]!!.type shouldBe typeOf() + + // end testing `readSqlTable` method + + // start testing `readSQLQuery` method + + // ith default inferNullability: Boolean = true + @Language("SQL") + val sqlQuery = """ + SELECT name, surname, age FROM TestTable1 + """.trimIndent() + + val df2 = DataFrame.readSqlQuery(connection, sqlQuery) + df2.schema().columns["name"]!!.type shouldBe typeOf() + df2.schema().columns["surname"]!!.type shouldBe typeOf() + df2.schema().columns["age"]!!.type shouldBe typeOf() + + val dataSchema2 = DataFrame.getSchemaForSqlQuery(connection, sqlQuery) + dataSchema2.columns.size shouldBe 3 + dataSchema2.columns["name"]!!.type shouldBe typeOf() + dataSchema2.columns["surname"]!!.type shouldBe typeOf() + dataSchema2.columns["age"]!!.type shouldBe typeOf() + + // with inferNullability: Boolean = false + val df3 = DataFrame.readSqlQuery(connection, sqlQuery, inferNullability = false) + df3.schema().columns["name"]!!.type shouldBe typeOf() // <=== this column changed a type because it doesn't contain nulls + df3.schema().columns["surname"]!!.type shouldBe typeOf() + df3.schema().columns["age"]!!.type shouldBe typeOf() + + // end testing `readSQLQuery` method + + // start testing `readResultSet` method + + connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st -> + @Language("SQL") + val selectStatement = "SELECT * FROM TestTable1" + + st.executeQuery(selectStatement).use { rs -> + // ith default inferNullability: Boolean = true + val df4 = DataFrame.readResultSet(rs, H2) + 
df4.schema().columns["id"]!!.type shouldBe typeOf() + df4.schema().columns["name"]!!.type shouldBe typeOf() + df4.schema().columns["surname"]!!.type shouldBe typeOf() + df4.schema().columns["age"]!!.type shouldBe typeOf() + + rs.beforeFirst() + + val dataSchema3 = DataFrame.getSchemaForResultSet(rs, H2) + dataSchema3.columns.size shouldBe 4 + dataSchema3.columns["id"]!!.type shouldBe typeOf() + dataSchema3.columns["name"]!!.type shouldBe typeOf() + dataSchema3.columns["surname"]!!.type shouldBe typeOf() + dataSchema3.columns["age"]!!.type shouldBe typeOf() + + // with inferNullability: Boolean = false + rs.beforeFirst() + + val df5 = DataFrame.readResultSet(rs, H2, inferNullability = false) + df5.schema().columns["id"]!!.type shouldBe typeOf() + df5.schema().columns["name"]!!.type shouldBe typeOf() // <=== this column changed a type because it doesn't contain nulls + df5.schema().columns["surname"]!!.type shouldBe typeOf() + df5.schema().columns["age"]!!.type shouldBe typeOf() + } + } + // end testing `readResultSet` method + + connection.createStatement().execute("DROP TABLE TestTable1") + } +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 402002d817..38e1e045ef 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -31,6 +31,7 @@ fuel = "2.3.1" poi = "5.2.5" mariadb = "3.3.2" h2db = "2.2.224" +mssql = "12.6.1.jre11" mysql = "8.3.0" postgresql = "42.7.2" sqlite = "3.45.1.0" @@ -75,6 +76,7 @@ fuel = { group = "com.github.kittinunf.fuel", name = "fuel", version.ref = "fuel poi = { group = "org.apache.poi", name = "poi", version.ref = "poi" } mariadb = { group = "org.mariadb.jdbc", name = "mariadb-java-client", version.ref = "mariadb" } h2db = { group = "com.h2database", name = "h2", version.ref = "h2db" } +mssql = { group = "com.microsoft.sqlserver", name = "mssql-jdbc", version.ref = "mssql" } mysql = { group = "com.mysql", name = "mysql-connector-j", version.ref = "mysql" } postgresql = { group = "org.postgresql", name 
= "postgresql", version.ref = "postgresql" } sqlite = { group = "org.xerial", name = "sqlite-jdbc", version.ref = "sqlite" }