Skip to content

Commit f7cd1a1

Browse files
committed
Fix review
1 parent 265b893 commit f7cd1a1

File tree

1 file changed

+42
-48
lines changed
  • dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2

1 file changed

+42
-48
lines changed

dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,25 @@ class JdbcTest {
9696
@BeforeClass
9797
@JvmStatic
9898
fun setUpClass() {
99+
initializeConnection()
100+
initializeDataSource()
101+
createTablesAndData()
102+
}
103+
104+
private fun initializeConnection() {
99105
connection = DriverManager.getConnection(URL)
106+
}
100107

101-
// Initialize DataSource
108+
private fun initializeDataSource() {
102109
val config = HikariConfig().apply {
103110
jdbcUrl = URL
104111
maximumPoolSize = 10
105112
minimumIdle = 2
106113
}
107114
dataSource = HikariDataSource(config)
115+
}
108116

117+
private fun createTablesAndData() {
109118
// Create table Customer
110119
@Language("SQL")
111120
val createCustomerTableQuery = """
@@ -155,24 +164,15 @@ class JdbcTest {
155164
}
156165

157166
// Helper assertion functions
158-
private fun assertCustomerData(df: AnyFrame) {
159-
val casted = df.cast<Customer>()
160-
casted.rowsCount() shouldBe 4
161-
casted.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2
162-
casted[0][1] shouldBe "John"
163-
}
164-
165-
private fun assertCustomerDataWithLimit(df: AnyFrame) {
166-
val casted = df.cast<Customer>()
167-
casted.rowsCount() shouldBe 1
168-
casted.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1
169-
casted[0][1] shouldBe "John"
170-
}
171-
172-
private fun assertCustomerDataWithLimitTwo(df: AnyFrame) {
167+
private fun assertCustomerData(df: AnyFrame, expectedRows: Int = 4) {
173168
val casted = df.cast<Customer>()
174-
casted.rowsCount() shouldBe 2
175-
casted.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1
169+
casted.rowsCount() shouldBe expectedRows
170+
val expectedOlderThan30 = when (expectedRows) {
171+
4 -> 2
172+
2 -> 1
173+
else -> 1 // for 1 row or other small limits in tests
174+
}
175+
casted.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe expectedOlderThan30
176176
casted[0][1] shouldBe "John"
177177
}
178178

@@ -181,16 +181,10 @@ class JdbcTest {
181181
schema.columns["name"]!!.type shouldBe typeOf<String?>()
182182
}
183183

184-
private fun assertCustomerSalesData(df: AnyFrame) {
185-
val casted = df.cast<CustomerSales>()
186-
casted.rowsCount() shouldBe 2
187-
casted.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1
188-
casted[0][0] shouldBe "John"
189-
}
190-
191-
private fun assertCustomerSalesDataWithLimit(df: AnyFrame) {
184+
private fun assertCustomerSalesData(df: AnyFrame, expectedRows: Int = 2) {
192185
val casted = df.cast<CustomerSales>()
193-
casted.rowsCount() shouldBe 1
186+
casted.rowsCount() shouldBe expectedRows
187+
// In current tests, regardless of limit (2 or 1), the count of totalSalesAmount > 100 is 1
194188
casted.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1
195189
casted[0][0] shouldBe "John"
196190
}
@@ -425,7 +419,7 @@ class JdbcTest {
425419
assertCustomerData(df)
426420

427421
val df1 = DataFrame.readSqlTable(connection, tableName, 1)
428-
assertCustomerDataWithLimit(df1)
422+
assertCustomerData(df1, 1)
429423

430424
val dataSchema = DataFrameSchema.readSqlTable(connection, tableName)
431425
assertCustomerSchema(dataSchema)
@@ -435,7 +429,7 @@ class JdbcTest {
435429
assertCustomerData(df2)
436430

437431
val df3 = DataFrame.readSqlTable(dbConfig, tableName, 1)
438-
assertCustomerDataWithLimit(df3)
432+
assertCustomerData(df3, 1)
439433

440434
val dataSchema1 = DataFrameSchema.readSqlTable(dbConfig, tableName)
441435
assertCustomerSchema(dataSchema1)
@@ -448,7 +442,7 @@ class JdbcTest {
448442
assertCustomerData(df)
449443

450444
val df1 = connection.readDataFrame(tableName, 1)
451-
assertCustomerDataWithLimit(df1)
445+
assertCustomerData(df1, 1)
452446

453447
val dataSchema = connection.readDataFrameSchema(tableName)
454448
assertCustomerSchema(dataSchema)
@@ -458,7 +452,7 @@ class JdbcTest {
458452
assertCustomerData(df2)
459453

460454
val df3 = dbConfig.readDataFrame(tableName, 1)
461-
assertCustomerDataWithLimit(df3)
455+
assertCustomerData(df3, 1)
462456

463457
val dataSchema1 = dbConfig.readDataFrameSchema(tableName)
464458
assertCustomerSchema(dataSchema1)
@@ -470,11 +464,11 @@ class JdbcTest {
470464

471465
repeat(10) {
472466
val df1 = DataFrame.readSqlTable(connection, tableName, 2)
473-
assertCustomerDataWithLimitTwo(df1)
467+
assertCustomerData(df1, 2)
474468

475469
val dbConfig = DbConnectionConfig(url = URL)
476470
val df2 = DataFrame.readSqlTable(dbConfig, tableName, 2)
477-
assertCustomerDataWithLimitTwo(df2)
471+
assertCustomerData(df2, 2)
478472
}
479473
}
480474

@@ -491,7 +485,7 @@ class JdbcTest {
491485
rs.beforeFirst()
492486

493487
val df1 = DataFrame.readResultSet(rs, H2(MySql), 1)
494-
assertCustomerDataWithLimit(df1)
488+
assertCustomerData(df1, 1)
495489

496490
rs.beforeFirst()
497491

@@ -506,7 +500,7 @@ class JdbcTest {
506500
rs.beforeFirst()
507501

508502
val df3 = DataFrame.readResultSet(rs, connection, 1)
509-
assertCustomerDataWithLimit(df3)
503+
assertCustomerData(df3, 1)
510504

511505
rs.beforeFirst()
512506

@@ -529,7 +523,7 @@ class JdbcTest {
529523
rs.beforeFirst()
530524

531525
val df1 = rs.readDataFrame(H2(MySql), 1)
532-
assertCustomerDataWithLimit(df1)
526+
assertCustomerData(df1, 1)
533527

534528
rs.beforeFirst()
535529

@@ -544,7 +538,7 @@ class JdbcTest {
544538
rs.beforeFirst()
545539

546540
val df3 = rs.readDataFrame(connection, 1)
547-
assertCustomerDataWithLimit(df3)
541+
assertCustomerData(df3, 1)
548542

549543
rs.beforeFirst()
550544

@@ -566,12 +560,12 @@ class JdbcTest {
566560
rs.beforeFirst()
567561

568562
val df1 = DataFrame.readResultSet(rs, H2(MySql), 2)
569-
assertCustomerDataWithLimitTwo(df1)
563+
assertCustomerData(df1, 2)
570564

571565
rs.beforeFirst()
572566

573567
val df2 = DataFrame.readResultSet(rs, connection, 2)
574-
assertCustomerDataWithLimitTwo(df2)
568+
assertCustomerData(df2, 2)
575569
}
576570
}
577571
}
@@ -870,7 +864,7 @@ class JdbcTest {
870864
assertCustomerSalesData(df)
871865

872866
val df1 = DataFrame.readSqlQuery(connection, CUSTOMER_SALES_QUERY, 1)
873-
assertCustomerSalesDataWithLimit(df1)
867+
assertCustomerSalesData(df1, 1)
874868

875869
val dataSchema = DataFrameSchema.readSqlQuery(connection, CUSTOMER_SALES_QUERY)
876870
assertCustomerSalesSchema(dataSchema)
@@ -880,7 +874,7 @@ class JdbcTest {
880874
assertCustomerSalesData(df2)
881875

882876
val df3 = DataFrame.readSqlQuery(dbConfig, CUSTOMER_SALES_QUERY, 1)
883-
assertCustomerSalesDataWithLimit(df3)
877+
assertCustomerSalesData(df3, 1)
884878

885879
val dataSchema1 = DataFrameSchema.readSqlQuery(dbConfig, CUSTOMER_SALES_QUERY)
886880
assertCustomerSalesSchema(dataSchema1)
@@ -892,7 +886,7 @@ class JdbcTest {
892886
assertCustomerSalesData(df)
893887

894888
val df1 = connection.readDataFrame(CUSTOMER_SALES_QUERY, 1)
895-
assertCustomerSalesDataWithLimit(df1)
889+
assertCustomerSalesData(df1, 1)
896890

897891
val dataSchema = connection.readDataFrameSchema(CUSTOMER_SALES_QUERY)
898892
assertCustomerSalesSchema(dataSchema)
@@ -902,7 +896,7 @@ class JdbcTest {
902896
assertCustomerSalesData(df2)
903897

904898
val df3 = dbConfig.readDataFrame(CUSTOMER_SALES_QUERY, 1)
905-
assertCustomerSalesDataWithLimit(df3)
899+
assertCustomerSalesData(df3, 1)
906900

907901
val dataSchema1 = dbConfig.readDataFrameSchema(CUSTOMER_SALES_QUERY)
908902
assertCustomerSalesSchema(dataSchema1)
@@ -1052,7 +1046,7 @@ class JdbcTest {
10521046
assertCustomerData(df)
10531047

10541048
val df1 = DataFrame.readSqlTable(dataSource, tableName, 1)
1055-
assertCustomerDataWithLimit(df1)
1049+
assertCustomerData(df1, 1)
10561050

10571051
val dataSchema = DataFrameSchema.readSqlTable(dataSource, tableName)
10581052
assertCustomerSchema(dataSchema)
@@ -1065,7 +1059,7 @@ class JdbcTest {
10651059
assertCustomerData(df)
10661060

10671061
val df1 = dataSource.readDataFrame(tableName, 1)
1068-
assertCustomerDataWithLimit(df1)
1062+
assertCustomerData(df1, 1)
10691063

10701064
val dataSchema = dataSource.readDataFrameSchema(tableName)
10711065
assertCustomerSchema(dataSchema)
@@ -1077,7 +1071,7 @@ class JdbcTest {
10771071
assertCustomerSalesData(df)
10781072

10791073
val df1 = DataFrame.readSqlQuery(dataSource, CUSTOMER_SALES_QUERY, 1)
1080-
assertCustomerSalesDataWithLimit(df1)
1074+
assertCustomerSalesData(df1, 1)
10811075

10821076
val dataSchema = DataFrameSchema.readSqlQuery(dataSource, CUSTOMER_SALES_QUERY)
10831077
assertCustomerSalesSchema(dataSchema)
@@ -1089,7 +1083,7 @@ class JdbcTest {
10891083
assertCustomerSalesData(df)
10901084

10911085
val df1 = dataSource.readDataFrame(CUSTOMER_SALES_QUERY, 1)
1092-
assertCustomerSalesDataWithLimit(df1)
1086+
assertCustomerSalesData(df1, 1)
10931087

10941088
val dataSchema = dataSource.readDataFrameSchema(CUSTOMER_SALES_QUERY)
10951089
assertCustomerSalesSchema(dataSchema)
@@ -1146,7 +1140,7 @@ class JdbcTest {
11461140

11471141
repeat(10) {
11481142
val df = DataFrame.readSqlTable(dataSource, tableName, 2)
1149-
assertCustomerDataWithLimitTwo(df)
1143+
assertCustomerData(df, 2)
11501144
}
11511145
}
11521146

0 commit comments

Comments
 (0)