Commit eb5dc74

mojodna authored and dongjoon-hyun committed
[SPARK-28097][SQL] Map ByteType to SMALLINT for PostgresDialect
## What changes were proposed in this pull request?

PostgreSQL doesn't have `TINYINT`, which would map directly, but `SMALLINT`s are sufficient for uni-directional translation.

A side-effect of this fix is that `AggregatedDialect` is now usable with multiple dialects targeting `jdbc:postgresql`, as `PostgresDialect.getJDBCType` no longer throws (for which reason backporting this fix would be lovely):

https://github.com/apache/spark/blob/1217996f1574f758d8cccc1c4e3846452d24b35b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala#L42

`dialects.flatMap` currently throws on the first attempt to get a JDBC type, preventing subsequent dialects in the chain from providing an alternative.

## How was this patch tested?

Unit tests.

Closes apache#24845 from mojodna/postgres-byte-type-mapping.

Authored-by: Seth Fitzsimmons <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 28774cd commit eb5dc74
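
The `AggregatedDialect` side-effect described in the commit message is easiest to see in code. Below is a minimal sketch, not the actual Spark source, modelled on the `AggregatedDialect.scala` line linked above: an aggregated dialect chains `getJDBCType` across its member dialects via `flatMap`, which is why a member dialect that throws blocks every dialect after it.

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
import org.apache.spark.sql.types.DataType

// Simplified sketch of how an aggregated dialect chains getJDBCType calls.
// Modelled on the linked AggregatedDialect.scala; not the exact source.
class AggregatedDialectSketch(dialects: List[JdbcDialect]) extends JdbcDialect {
  require(dialects.nonEmpty)

  override def canHandle(url: String): Boolean = dialects.forall(_.canHandle(url))

  override def getJDBCType(dt: DataType): Option[JdbcType] = {
    // flatMap evaluates every member dialect. A dialect that returns None
    // simply defers to the next one, but a dialect that throws (as
    // PostgresDialect.getJDBCType did for ByteType) aborts the whole chain,
    // so no later dialect can supply an alternative mapping.
    dialects.flatMap(_.getJDBCType(dt)).headOption
  }
}
```

With the `ByteType` case in `PostgresDialect` now returning `Some(...)` instead of throwing, the chain can complete and other dialects remain usable alongside it.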

File tree

3 files changed: +15 -6 lines changed

external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala

Lines changed: 13 additions & 0 deletions
```diff
@@ -206,4 +206,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
       """.stripMargin.replaceAll("\n", " "))
     assert(sql("select c1, c3 from queryOption").collect.toSet == expectedResult)
   }
+
+  test("write byte as smallint") {
+    sqlContext.createDataFrame(Seq((1.toByte, 2.toShort)))
+      .write.jdbc(jdbcUrl, "byte_to_smallint_test", new Properties)
+    val df = sqlContext.read.jdbc(jdbcUrl, "byte_to_smallint_test", new Properties)
+    val schema = df.schema
+    assert(schema.head.dataType == ShortType)
+    assert(schema(1).dataType == ShortType)
+    val rows = df.collect()
+    assert(rows.length === 1)
+    assert(rows(0).getShort(0) === 1)
+    assert(rows(0).getShort(1) === 2)
+  }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala

Lines changed: 1 addition & 2 deletions
```diff
@@ -73,14 +73,13 @@ private object PostgresDialect extends JdbcDialect {
     case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
     case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
     case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
-    case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT))
+    case ShortType | ByteType => Some(JdbcType("SMALLINT", Types.SMALLINT))
     case t: DecimalType => Some(
       JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
     case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
       getJDBCType(et).map(_.databaseTypeDefinition)
         .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
         .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
-    case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt");
     case _ => None
   }
 
```
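
With the throw removed, `PostgresDialect.getJDBCType(ByteType)` resolves to `SMALLINT`, so a second dialect registered for `jdbc:postgresql` URLs can coexist with the built-in one. A minimal sketch of that setup follows; `MyPostgresOverrides` and its `TEXT` mapping are hypothetical, purely for illustration.

```scala
import java.sql.Types

import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types.{DataType, StringType}

// Hypothetical dialect layered on top of the built-in PostgresDialect.
// It only supplies its own extra mappings and defers everything else.
object MyPostgresOverrides extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("TEXT", Types.VARCHAR)) // illustrative only
    case _ => None // defer to the other dialects in the aggregate
  }
}

object RegisterDialectExample {
  def main(args: Array[String]): Unit = {
    // Both MyPostgresOverrides and the built-in PostgresDialect handle
    // jdbc:postgresql URLs, so Spark combines them into an AggregatedDialect.
    // Before SPARK-28097 that aggregate threw for ByteType; now the type is
    // resolved through the chain (to SMALLINT).
    JdbcDialects.registerDialect(MyPostgresOverrides)
  }
}
```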

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala

Lines changed: 1 addition & 4 deletions
```diff
@@ -857,10 +857,7 @@ class JDBCSuite extends QueryTest
       Some(ArrayType(DecimalType.SYSTEM_DEFAULT)))
     assert(Postgres.getJDBCType(FloatType).map(_.databaseTypeDefinition).get == "FLOAT4")
     assert(Postgres.getJDBCType(DoubleType).map(_.databaseTypeDefinition).get == "FLOAT8")
-    val errMsg = intercept[IllegalArgumentException] {
-      Postgres.getJDBCType(ByteType)
-    }
-    assert(errMsg.getMessage contains "Unsupported type in postgresql: ByteType")
+    assert(Postgres.getJDBCType(ByteType).map(_.databaseTypeDefinition).get == "SMALLINT")
   }
 
   test("DerbyDialect jdbc type mapping") {
```
