Commit 53b996e

Update SQL all table/schemas reading functions to return maps with table names (#718)
* Update SQL table reading functions to return maps

  Refactored the SQL table reading functions to return map data structures instead of lists. This change makes it easy to correlate each dataframe with its underlying table. Additionally, the function comments and test cases were updated to match this change.

* Correct typos and ignore MSSQLTest

  Typographical errors in the comments of the readJdbc.kt file were corrected for clarity. Additionally, the MSSQLTest class in the mssqlTest.kt file was annotated with "@ignore" to skip the test in subsequent testing runs.

1 parent 2c4fdfc commit 53b996e

8 files changed, +50 -39 lines changed

dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt

Lines changed: 23 additions & 20 deletions
@@ -290,33 +290,35 @@ public fun DataFrame.Companion.readResultSet(
 }

 /**
- * Reads all tables from the given database using the provided database configuration and limit.
+ * Reads all non-system tables from a database and returns them
+ * as a map of SQL tables and corresponding dataframes using the provided database configuration and limit.
  *
  * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
  * @param [limit] the maximum number of rows to read from each table.
  * @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs.
  * @param [inferNullability] indicates how the column nullability should be inferred.
- * @return a list of [AnyFrame] objects representing the non-system tables from the database.
+ * @return a map of [String] to [AnyFrame] objects representing the non-system tables from the database.
  */
 public fun DataFrame.Companion.readAllSqlTables(
     dbConfig: DatabaseConfiguration,
     catalogue: String? = null,
     limit: Int = DEFAULT_LIMIT,
     inferNullability: Boolean = true,
-): List<AnyFrame> {
+): Map<String, AnyFrame> {
     DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
         return readAllSqlTables(connection, catalogue, limit, inferNullability)
     }
 }

 /**
- * Reads all non-system tables from a database and returns them as a list of data frames.
+ * Reads all non-system tables from a database and returns them
+ * as a map of SQL tables and corresponding dataframes.
  *
  * @param [connection] the database connection to read tables from.
  * @param [limit] the maximum number of rows to read from each table.
  * @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs.
  * @param [inferNullability] indicates how the column nullability should be inferred.
- * @return a list of [AnyFrame] objects representing the non-system tables from the database.
+ * @return a map of [String] to [AnyFrame] objects representing the non-system tables from the database.
  *
  * @see DriverManager.getConnection
  */
@@ -325,20 +327,20 @@ public fun DataFrame.Companion.readAllSqlTables(
     catalogue: String? = null,
     limit: Int = DEFAULT_LIMIT,
     inferNullability: Boolean = true,
-): List<AnyFrame> {
+): Map<String, AnyFrame> {
     val metaData = connection.metaData
     val url = connection.metaData.url
     val dbType = extractDBTypeFromUrl(url)

-    // exclude a system and other tables without data, but it looks like it supported badly for many databases
+    // exclude a system and other tables without data, but it looks like it is supported badly for many databases
     val tables = metaData.getTables(catalogue, null, null, arrayOf("TABLE"))

-    val dataFrames = mutableListOf<AnyFrame>()
+    val dataFrames = mutableMapOf<String, AnyFrame>()

     while (tables.next()) {
         val table = dbType.buildTableMetadata(tables)
         if (!dbType.isSystemTable(table)) {
-            // we filter her second time because of specific logic with SQLite and possible issues with future databases
+            // we filter here a second time because of specific logic with SQLite and possible issues with future databases
             val tableName = when {
                 catalogue != null && table.schemaName != null -> "$catalogue.${table.schemaName}.${table.name}"
                 catalogue != null && table.schemaName == null -> "$catalogue.${table.name}"
@@ -351,7 +353,7 @@ public fun DataFrame.Companion.readAllSqlTables(
             logger.debug { "Reading table: $tableName" }

             val dataFrame = readSqlTable(connection, tableName, limit, inferNullability)
-            dataFrames += dataFrame
+            dataFrames += tableName to dataFrame
             logger.debug { "Finished reading table: $tableName" }
         }
     }
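
The change from a list to a map in the hunk above can be illustrated with a small, dependency-free sketch. It is not the library's code: `FakeFrame` and `readAllFakeTables` are hypothetical stand-ins invented for this example, and the only thing borrowed from the diff is the `map += name to value` accumulation pattern.

```kotlin
// Stand-in for a dataframe: it only carries a row count, which is enough to show the lookup pattern.
data class FakeFrame(val rows: Int)

// Hypothetical reader: pairs each table name with the frame read from it,
// mirroring the `dataFrames += tableName to dataFrame` line in the diff.
fun readAllFakeTables(tables: Map<String, Int>, limit: Int): Map<String, FakeFrame> {
    val frames = mutableMapOf<String, FakeFrame>()
    for ((tableName, rowCount) in tables) {
        // Apply the per-table row limit, then store the result under the table name.
        frames += tableName to FakeFrame(minOf(rowCount, limit))
    }
    return frames
}

fun main() {
    val frames = readAllFakeTables(mapOf("Customer" to 4, "Sale" to 10), limit = 5)
    // With a list the caller had to know each table's position; with a map the name is enough.
    println(frames["Sale"]?.rows)    // prints 5
    println(frames["Missing"]?.rows) // prints null
}
```

A lookup for a table that was not read returns `null` rather than throwing, so callers can decide how to handle a missing table.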
@@ -474,24 +476,24 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, conne
 }

 /**
- * Retrieves the schema of all non-system tables in the database using the provided database configuration.
+ * Retrieves the schemas of all non-system tables in the database using the provided database configuration.
  *
  * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
- * @return a list of [DataFrameSchema] objects representing the schema of each non-system table.
+ * @return a map of [String, DataFrameSchema] objects representing the table name and its schema for each non-system table.
  */
-public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): List<DataFrameSchema> {
+public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): Map<String, DataFrameSchema> {
     DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
         return getSchemaForAllSqlTables(connection)
     }
 }

 /**
- * Retrieves the schema of all non-system tables in the database using the provided database connection.
+ * Retrieves the schemas of all non-system tables in the database using the provided database connection.
  *
  * @param [connection] the database connection.
- * @return a list of [DataFrameSchema] objects representing the schema of each non-system table.
+ * @return a map of [String, DataFrameSchema] objects representing the table name and its schema for each non-system table.
  */
-public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): List<DataFrameSchema> {
+public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): Map<String, DataFrameSchema> {
     val metaData = connection.metaData
     val url = connection.metaData.url
     val dbType = extractDBTypeFromUrl(url)
@@ -500,14 +502,15 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection):
     // exclude a system and other tables without data
     val tables = metaData.getTables(null, null, null, tableTypes)

-    val dataFrameSchemas = mutableListOf<DataFrameSchema>()
+    val dataFrameSchemas = mutableMapOf<String, DataFrameSchema>()

     while (tables.next()) {
         val jdbcTable = dbType.buildTableMetadata(tables)
         if (!dbType.isSystemTable(jdbcTable)) {
-            // we filter her second time because of specific logic with SQLite and possible issues with future databases
-            val dataFrameSchema = getSchemaForSqlTable(connection, jdbcTable.name)
-            dataFrameSchemas += dataFrameSchema
+            // we filter her a second time because of specific logic with SQLite and possible issues with future databases
+            val tableName = jdbcTable.name
+            val dataFrameSchema = getSchemaForSqlTable(connection, tableName)
+            dataFrameSchemas += tableName to dataFrameSchema
         }
     }
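
One property of this schema-reading loop may be worth keeping in mind: it keys the map by the bare `jdbcTable.name`, whereas `readAllSqlTables` builds a qualified name when a catalogue is given. The sketch below shows what a Kotlin mutable map does when two entries share a key. `TableRef` and both functions are hypothetical names for illustration only, and whether a real database driver ever reports two same-named tables here is not stated in the commit.

```kotlin
// Hypothetical table descriptor; the library's real metadata class is not shown in the diff.
data class TableRef(val schemaName: String?, val name: String)

// Keys by bare table name, as the schema-reading loop in the diff does.
fun keyByBareName(tables: List<TableRef>): Map<String, TableRef> {
    val result = mutableMapOf<String, TableRef>()
    for (table in tables) {
        // A later entry with the same key replaces the earlier one.
        result += table.name to table
    }
    return result
}

// Alternative (an assumption, not part of the commit): qualify the key with the schema when present.
fun keyByQualifiedName(tables: List<TableRef>): Map<String, TableRef> =
    tables.associateBy { t -> if (t.schemaName != null) "${t.schemaName}.${t.name}" else t.name }

fun main() {
    val tables = listOf(TableRef("sales", "Customer"), TableRef("archive", "Customer"))
    println(keyByBareName(tables).size)      // prints 1
    println(keyByQualifiedName(tables).keys) // prints [sales.Customer, archive.Customer]
}
```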

dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt

Lines changed: 14 additions & 6 deletions
@@ -597,7 +597,11 @@ class JdbcTest {

     @Test
     fun `read from all tables`() {
-        val dataframes = DataFrame.readAllSqlTables(connection)
+        val dataFrameMap = DataFrame.readAllSqlTables(connection)
+        dataFrameMap.containsKey("Customer") shouldBe true
+        dataFrameMap.containsKey("Sale") shouldBe true
+
+        val dataframes = dataFrameMap.values.toList()

         val customerDf = dataframes[0].cast<Customer>()

@@ -611,7 +615,7 @@ class JdbcTest {
         saleDf.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3
         (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0

-        val dataframes1 = DataFrame.readAllSqlTables(connection, limit = 1)
+        val dataframes1 = DataFrame.readAllSqlTables(connection, limit = 1).values.toList()

         val customerDf1 = dataframes1[0].cast<Customer>()

@@ -625,7 +629,11 @@ class JdbcTest {
         saleDf1.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 1
         (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0

-        val dataSchemas = DataFrame.getSchemaForAllSqlTables(connection)
+        val dataFrameSchemaMap = DataFrame.getSchemaForAllSqlTables(connection)
+        dataFrameSchemaMap.containsKey("Customer") shouldBe true
+        dataFrameSchemaMap.containsKey("Sale") shouldBe true
+
+        val dataSchemas = dataFrameSchemaMap.values.toList()

         val customerDataSchema = dataSchemas[0]
         customerDataSchema.columns.size shouldBe 3
@@ -637,7 +645,7 @@ class JdbcTest {
         saleDataSchema.columns["amount"]!!.type shouldBe typeOf<BigDecimal>()

         val dbConfig = DatabaseConfiguration(url = URL)
-        val dataframes2 = DataFrame.readAllSqlTables(dbConfig)
+        val dataframes2 = DataFrame.readAllSqlTables(dbConfig).values.toList()

         val customerDf2 = dataframes2[0].cast<Customer>()

@@ -651,7 +659,7 @@ class JdbcTest {
         saleDf2.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3
         (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0

-        val dataframes3 = DataFrame.readAllSqlTables(dbConfig, limit = 1)
+        val dataframes3 = DataFrame.readAllSqlTables(dbConfig, limit = 1).values.toList()

         val customerDf3 = dataframes3[0].cast<Customer>()

@@ -665,7 +673,7 @@ class JdbcTest {
         saleDf3.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 1
         (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0

-        val dataSchemas1 = DataFrame.getSchemaForAllSqlTables(dbConfig)
+        val dataSchemas1 = DataFrame.getSchemaForAllSqlTables(dbConfig).values.toList()

         val customerDataSchema1 = dataSchemas1[0]
         customerDataSchema1.columns.size shouldBe 3
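
The updated tests keep their positional assertions by calling `.values.toList()` and indexing with `[0]`. This works because the map is built with `mutableMapOf`, which preserves insertion order in Kotlin. The following is a minimal sketch of both access styles; `sampleRowCounts` and its row counts are made up for illustration and do not come from the test database.

```kotlin
// Hypothetical data: table names mapped to row counts, inserted in a fixed order.
fun sampleRowCounts(): Map<String, Int> {
    // mutableMapOf() returns a map that preserves the order in which entries were added.
    val rowCounts = mutableMapOf<String, Int>()
    rowCounts += "Customer" to 4
    rowCounts += "Sale" to 10
    return rowCounts
}

fun main() {
    val rowCounts = sampleRowCounts()

    // Positional access, as the migrated tests do: depends on insertion order.
    val positional = rowCounts.values.toList()
    println(positional[0]) // prints 4

    // Keyed access: independent of the order in which tables were read.
    println(rowCounts.getValue("Sale")) // prints 10
}
```

Keyed access is arguably the more robust choice for new assertions, since the order in which a database reports its tables is outside the test's control.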

dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt

Lines changed: 1 addition & 1 deletion
@@ -370,7 +370,7 @@ class MariadbTest {

     @Test
     fun `read from all tables`() {
-        val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 1000)
+        val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 1000).values.toList()

         val table1Df = dataframes[0].cast<Table1MariaDb>()

dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt

Lines changed: 1 addition & 1 deletion
@@ -277,7 +277,7 @@ class MSSQLTest {

     @Test
     fun `read from all tables`() {
-        val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 4)
+        val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 4).values.toList()

         val table1Df = dataframes[0].cast<Table1MSSSQL>()

dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt

Lines changed: 1 addition & 1 deletion
@@ -370,7 +370,7 @@ class MySqlTest {

     @Test
     fun `read from all tables`() {
-        val dataframes = DataFrame.readAllSqlTables(connection)
+        val dataframes = DataFrame.readAllSqlTables(connection).values.toList()

         val table1Df = dataframes[0].cast<Table1MySql>()

dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt

Lines changed: 1 addition & 1 deletion
@@ -298,7 +298,7 @@ class PostgresTest {

     @Test
     fun `read from all tables`() {
-        val dataframes = DataFrame.readAllSqlTables(connection)
+        val dataframes = DataFrame.readAllSqlTables(connection).values.toList()

         val table1Df = dataframes[0].cast<Table1>()

dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt

Lines changed: 1 addition & 1 deletion
@@ -193,7 +193,7 @@ class SqliteTest {

     @Test
     fun `read from all tables`() {
-        val dataframes = DataFrame.readAllSqlTables(connection)
+        val dataframes = DataFrame.readAllSqlTables(connection).values.toList()

         val customerDf = dataframes[0].cast<CustomerSQLite>()

docs/StardustDocs/topics/readSqlDatabases.md

Lines changed: 8 additions & 8 deletions
@@ -59,7 +59,7 @@ In the second, be sure that you can establish a connection to the database.

 For this, usually, you need to have three things: a URL to a database, a username and a password.

-Call one of the following functions to obtain data from a database and transform it to the dataframe.
+Call one of the following functions to collect data from a database and transform it to the dataframe.

 For example, if you have a local PostgreSQL database named as `testDatabase` with table `Customer`,
 you could read first 100 rows and print the data just copying the code below:
@@ -105,7 +105,7 @@ Next, import `Kotlin DataFrame` library in the cell below.
 **NOTE:** The order of cell execution is important,
 the dataframe library is waiting for a JDBC driver to force classloading.

-Find full example Notebook [here](https://github.com/zaleslaw/KotlinDataFrame-SQL-Examples/blob/master/notebooks/imdb.ipynb).
+Find a full example Notebook [here](https://github.com/zaleslaw/KotlinDataFrame-SQL-Examples/blob/master/notebooks/imdb.ipynb).


 ## Reading Specific Tables
@@ -315,9 +315,9 @@ connection.close()
 These functions read all data from all tables in the connected database.
 Variants with a limit parameter restrict how many rows will be read from each table.

-**readAllSqlTables(connection: Connection): List\<AnyFrame>**
+**readAllSqlTables(connection: Connection): Map\<String, AnyFrame>**

-Retrieves data from all the non-system tables in the SQL database and returns them as a list of AnyFrame objects.
+Retrieves data from all the non-system tables in the SQL database and returns them as a map of table names to AnyFrame objects.

 The `dbConfig: DatabaseConfiguration` parameter represents the configuration for a database connection,
 created under the hood and managed by the library. Typically, it requires a URL, username and password.
@@ -330,7 +330,7 @@ val dbConfig = DatabaseConfiguration("URL_TO_CONNECT_DATABASE", "USERNAME", "PAS
 val dataframes = DataFrame.readAllSqlTables(dbConfig)
 ```

-**readAllSqlTables(connection: Connection, limit: Int): List\<AnyFrame>**
+**readAllSqlTables(connection: Connection, limit: Int): Map\<String, AnyFrame>**

 A variant of the previous function,
 but with an added `limit: Int` parameter that allows setting the maximum number of records to be read from each table.
@@ -493,10 +493,10 @@ connection.close()
 These functions return a list of all [`DataFrameSchema`](schema.md) from all the non-system tables in the SQL database.
 They can be called with either a database configuration or a connection.

-**getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): List\<DataFrameSchema>**
+**getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): Map\<String, DataFrameSchema>**

 This function retrieves the schema of all tables from an SQL database
-and returns them as a list of [`DataFrameSchema`](schema.md).
+and returns them as a map of table names to [`DataFrameSchema`](schema.md) objects.

 The `dbConfig: DatabaseConfiguration` parameter represents the configuration for a database connection,
 created under the hood and managed by the library. Typically, it requires a URL, username and password.
@@ -509,7 +509,7 @@ val dbConfig = DatabaseConfiguration("URL_TO_CONNECT_DATABASE", "USERNAME", "PAS
 val schemas = DataFrame.getSchemaForAllSqlTables(dbConfig)
 ```

-**getSchemaForAllSqlTables(connection: Connection): List\<DataFrameSchema>**
+**getSchemaForAllSqlTables(connection: Connection): Map\<String, DataFrameSchema>**

 This function retrieves the schema of all tables using a JDBC connection: `Connection` object
 and returns them as a list of [`DataFrameSchema`](schema.md).
