Skip to content

Commit 6d7ebf2

Browse files
HyukjinKwoncloud-fan
authored andcommitted
[SPARK-22165][SQL] Fixes type conflicts between double, long, decimals, dates and timestamps in partition column
## What changes were proposed in this pull request? This PR proposes to add a rule that re-uses `TypeCoercion.findWiderCommonType` when resolving type conflicts in partition values. Currently, this uses numeric precedence-like comparison; therefore, it looks introducing failures for type conflicts between timestamps, dates and decimals, please see: ```scala private val upCastingOrder: Seq[DataType] = Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) ... literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_)) ``` The codes below: ```scala val df = Seq((1, "2015-01-01"), (2, "2016-01-01 00:00:00")).toDF("i", "ts") df.write.format("parquet").partitionBy("ts").save("/tmp/foo") spark.read.load("/tmp/foo").printSchema() val df = Seq((1, "1"), (2, "1" * 30)).toDF("i", "decimal") df.write.format("parquet").partitionBy("decimal").save("/tmp/bar") spark.read.load("/tmp/bar").printSchema() ``` produces output as below: **Before** ``` root |-- i: integer (nullable = true) |-- ts: date (nullable = true) root |-- i: integer (nullable = true) |-- decimal: integer (nullable = true) ``` **After** ``` root |-- i: integer (nullable = true) |-- ts: timestamp (nullable = true) root |-- i: integer (nullable = true) |-- decimal: decimal(30,0) (nullable = true) ``` ### Type coercion table: This PR proposes the type conflict resolusion as below: **Before** |InputA \ InputB|`NullType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`DateType`|`TimestampType`|`StringType`| |------------------------|----------|----------|----------|----------|----------|----------|----------|----------| |**`NullType`**|`StringType`|`IntegerType`|`LongType`|`StringType`|`DoubleType`|`StringType`|`StringType`|`StringType`| |**`IntegerType`**|`IntegerType`|`IntegerType`|`LongType`|`IntegerType`|`DoubleType`|`IntegerType`|`IntegerType`|`StringType`| |**`LongType`**|`LongType`|`LongType`|`LongType`|`LongType`|`DoubleType`|`LongType`|`LongType`|`StringType`| |**`DecimalType(38,0)`**|`StringType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`DecimalType(38,0)`|`DecimalType(38,0)`|`StringType`| |**`DoubleType`**|`DoubleType`|`DoubleType`|`DoubleType`|`DoubleType`|`DoubleType`|`DoubleType`|`DoubleType`|`StringType`| |**`DateType`**|`StringType`|`IntegerType`|`LongType`|`DateType`|`DoubleType`|`DateType`|`DateType`|`StringType`| |**`TimestampType`**|`StringType`|`IntegerType`|`LongType`|`TimestampType`|`DoubleType`|`TimestampType`|`TimestampType`|`StringType`| |**`StringType`**|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`| **After** |InputA \ InputB|`NullType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`DateType`|`TimestampType`|`StringType`| |------------------------|----------|----------|----------|----------|----------|----------|----------|----------| |**`NullType`**|`NullType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`DateType`|`TimestampType`|`StringType`| |**`IntegerType`**|`IntegerType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`StringType`|`StringType`|`StringType`| |**`LongType`**|`LongType`|`LongType`|`LongType`|`DecimalType(38,0)`|`StringType`|`StringType`|`StringType`|`StringType`| |**`DecimalType(38,0)`**|`DecimalType(38,0)`|`DecimalType(38,0)`|`DecimalType(38,0)`|`DecimalType(38,0)`|`StringType`|`StringType`|`StringType`|`StringType`| |**`DoubleType`**|`DoubleType`|`DoubleType`|`StringType`|`StringType`|`DoubleType`|`StringType`|`StringType`|`StringType`| |**`DateType`**|`DateType`|`StringType`|`StringType`|`StringType`|`StringType`|`DateType`|`TimestampType`|`StringType`| |**`TimestampType`**|`TimestampType`|`StringType`|`StringType`|`StringType`|`StringType`|`TimestampType`|`TimestampType`|`StringType`| |**`StringType`**|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`| This was produced by: ```scala test("Print out chart") { val supportedTypes: Seq[DataType] = Seq( NullType, IntegerType, LongType, DecimalType(38, 0), DoubleType, DateType, TimestampType, StringType) // Old type conflict resolution: val upCastingOrder: Seq[DataType] = Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) def oldResolveTypeConflicts(dataTypes: Seq[DataType]): DataType = { val topType = dataTypes.maxBy(upCastingOrder.indexOf(_)) if (topType == NullType) StringType else topType } println(s"|InputA \\ InputB|${supportedTypes.map(dt => s"`${dt.toString}`").mkString("|")}|") println(s"|------------------------|${supportedTypes.map(_ => "----------").mkString("|")}|") supportedTypes.foreach { inputA => val types = supportedTypes.map(inputB => oldResolveTypeConflicts(Seq(inputA, inputB))) println(s"|**`$inputA`**|${types.map(dt => s"`${dt.toString}`").mkString("|")}|") } // New type conflict resolution: def newResolveTypeConflicts(dataTypes: Seq[DataType]): DataType = { dataTypes.fold[DataType](NullType)(findWiderTypeForPartitionColumn) } println(s"|InputA \\ InputB|${supportedTypes.map(dt => s"`${dt.toString}`").mkString("|")}|") println(s"|------------------------|${supportedTypes.map(_ => "----------").mkString("|")}|") supportedTypes.foreach { inputA => val types = supportedTypes.map(inputB => newResolveTypeConflicts(Seq(inputA, inputB))) println(s"|**`$inputA`**|${types.map(dt => s"`${dt.toString}`").mkString("|")}|") } } ``` ## How was this patch tested? Unit tests added in `ParquetPartitionDiscoverySuite`. Author: hyukjinkwon <[email protected]> Closes #19389 from HyukjinKwon/partition-type-coercion.
1 parent 2d868d9 commit 6d7ebf2

File tree

4 files changed

+235
-23
lines changed

4 files changed

+235
-23
lines changed

docs/sql-programming-guide.md

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,6 +1577,145 @@ options.
15771577

15781578
- Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
15791579
- The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles.
1580+
- Partition column inference previously found incorrect common type for different inferred types, for example, previously it ended up with double type as the common type for double type and date type. Now it finds the correct common type for such conflicts. The conflict resolution follows the table below:
1581+
1582+
<table class="table">
1583+
<tr>
1584+
<th>
1585+
<b>InputA \ InputB</b>
1586+
</th>
1587+
<th>
1588+
<b>NullType</b>
1589+
</th>
1590+
<th>
1591+
<b>IntegerType</b>
1592+
</th>
1593+
<th>
1594+
<b>LongType</b>
1595+
</th>
1596+
<th>
1597+
<b>DecimalType(38,0)*</b>
1598+
</th>
1599+
<th>
1600+
<b>DoubleType</b>
1601+
</th>
1602+
<th>
1603+
<b>DateType</b>
1604+
</th>
1605+
<th>
1606+
<b>TimestampType</b>
1607+
</th>
1608+
<th>
1609+
<b>StringType</b>
1610+
</th>
1611+
</tr>
1612+
<tr>
1613+
<td>
1614+
<b>NullType</b>
1615+
</td>
1616+
<td>NullType</td>
1617+
<td>IntegerType</td>
1618+
<td>LongType</td>
1619+
<td>DecimalType(38,0)</td>
1620+
<td>DoubleType</td>
1621+
<td>DateType</td>
1622+
<td>TimestampType</td>
1623+
<td>StringType</td>
1624+
</tr>
1625+
<tr>
1626+
<td>
1627+
<b>IntegerType</b>
1628+
</td>
1629+
<td>IntegerType</td>
1630+
<td>IntegerType</td>
1631+
<td>LongType</td>
1632+
<td>DecimalType(38,0)</td>
1633+
<td>DoubleType</td>
1634+
<td>StringType</td>
1635+
<td>StringType</td>
1636+
<td>StringType</td>
1637+
</tr>
1638+
<tr>
1639+
<td>
1640+
<b>LongType</b>
1641+
</td>
1642+
<td>LongType</td>
1643+
<td>LongType</td>
1644+
<td>LongType</td>
1645+
<td>DecimalType(38,0)</td>
1646+
<td>StringType</td>
1647+
<td>StringType</td>
1648+
<td>StringType</td>
1649+
<td>StringType</td>
1650+
</tr>
1651+
<tr>
1652+
<td>
1653+
<b>DecimalType(38,0)*</b>
1654+
</td>
1655+
<td>DecimalType(38,0)</td>
1656+
<td>DecimalType(38,0)</td>
1657+
<td>DecimalType(38,0)</td>
1658+
<td>DecimalType(38,0)</td>
1659+
<td>StringType</td>
1660+
<td>StringType</td>
1661+
<td>StringType</td>
1662+
<td>StringType</td>
1663+
</tr>
1664+
<tr>
1665+
<td>
1666+
<b>DoubleType</b>
1667+
</td>
1668+
<td>DoubleType</td>
1669+
<td>DoubleType</td>
1670+
<td>StringType</td>
1671+
<td>StringType</td>
1672+
<td>DoubleType</td>
1673+
<td>StringType</td>
1674+
<td>StringType</td>
1675+
<td>StringType</td>
1676+
</tr>
1677+
<tr>
1678+
<td>
1679+
<b>DateType</b>
1680+
</td>
1681+
<td>DateType</td>
1682+
<td>StringType</td>
1683+
<td>StringType</td>
1684+
<td>StringType</td>
1685+
<td>StringType</td>
1686+
<td>DateType</td>
1687+
<td>TimestampType</td>
1688+
<td>StringType</td>
1689+
</tr>
1690+
<tr>
1691+
<td>
1692+
<b>TimestampType</b>
1693+
</td>
1694+
<td>TimestampType</td>
1695+
<td>StringType</td>
1696+
<td>StringType</td>
1697+
<td>StringType</td>
1698+
<td>StringType</td>
1699+
<td>TimestampType</td>
1700+
<td>TimestampType</td>
1701+
<td>StringType</td>
1702+
</tr>
1703+
<tr>
1704+
<td>
1705+
<b>StringType</b>
1706+
</td>
1707+
<td>StringType</td>
1708+
<td>StringType</td>
1709+
<td>StringType</td>
1710+
<td>StringType</td>
1711+
<td>StringType</td>
1712+
<td>StringType</td>
1713+
<td>StringType</td>
1714+
<td>StringType</td>
1715+
</tr>
1716+
</table>
1717+
1718+
Note that, for <b>DecimalType(38,0)*</b>, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type.
15801719

15811720
## Upgrading From Spark SQL 2.1 to 2.2
15821721

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ object TypeCoercion {
155155
* i.e. the main difference with [[findTightestCommonType]] is that here we allow some
156156
* loss of precision when widening decimal and double, and promotion to string.
157157
*/
158-
private[analysis] def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = {
158+
def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = {
159159
findTightestCommonType(t1, t2)
160160
.orElse(findWiderTypeForDecimal(t1, t2))
161161
.orElse(stringPromotion(t1, t2))

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.hadoop.fs.Path
2828

2929
import org.apache.spark.sql.AnalysisException
3030
import org.apache.spark.sql.catalyst.InternalRow
31-
import org.apache.spark.sql.catalyst.analysis.Resolver
31+
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion}
3232
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
3333
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
3434
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -309,13 +309,8 @@ object PartitioningUtils {
309309
}
310310

311311
/**
312-
* Resolves possible type conflicts between partitions by up-casting "lower" types. The up-
313-
* casting order is:
314-
* {{{
315-
* NullType ->
316-
* IntegerType -> LongType ->
317-
* DoubleType -> StringType
318-
* }}}
312+
* Resolves possible type conflicts between partitions by up-casting "lower" types using
313+
* [[findWiderTypeForPartitionColumn]].
319314
*/
320315
def resolvePartitions(
321316
pathsWithPartitionValues: Seq[(Path, PartitionValues)],
@@ -372,11 +367,31 @@ object PartitioningUtils {
372367
suspiciousPaths.map("\t" + _).mkString("\n", "\n", "")
373368
}
374369

370+
// scalastyle:off line.size.limit
375371
/**
376-
* Converts a string to a [[Literal]] with automatic type inference. Currently only supports
377-
* [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType]], [[DateType]]
372+
* Converts a string to a [[Literal]] with automatic type inference. Currently only supports
373+
* [[NullType]], [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType]], [[DateType]]
378374
* [[TimestampType]], and [[StringType]].
375+
*
376+
* When resolving conflicts, it follows the table below:
377+
*
378+
* +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+
379+
* | InputA \ InputB | NullType | IntegerType | LongType | DecimalType(38,0)* | DoubleType | DateType | TimestampType | StringType |
380+
* +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+
381+
* | NullType | NullType | IntegerType | LongType | DecimalType(38,0) | DoubleType | DateType | TimestampType | StringType |
382+
* | IntegerType | IntegerType | IntegerType | LongType | DecimalType(38,0) | DoubleType | StringType | StringType | StringType |
383+
* | LongType | LongType | LongType | LongType | DecimalType(38,0) | StringType | StringType | StringType | StringType |
384+
* | DecimalType(38,0)* | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | StringType | StringType | StringType | StringType |
385+
* | DoubleType | DoubleType | DoubleType | StringType | StringType | DoubleType | StringType | StringType | StringType |
386+
* | DateType | DateType | StringType | StringType | StringType | StringType | DateType | TimestampType | StringType |
387+
* | TimestampType | TimestampType | StringType | StringType | StringType | StringType | TimestampType | TimestampType | StringType |
388+
* | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType |
389+
* +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+
390+
* Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other
391+
* combinations of scales and precisions because currently we only infer decimal type like
392+
* `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type.
379393
*/
394+
// scalastyle:on line.size.limit
380395
private[datasources] def inferPartitionColumnValue(
381396
raw: String,
382397
typeInference: Boolean,
@@ -427,9 +442,6 @@ object PartitioningUtils {
427442
}
428443
}
429444

430-
private val upCastingOrder: Seq[DataType] =
431-
Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType)
432-
433445
def validatePartitionColumn(
434446
schema: StructType,
435447
partitionColumns: Seq[String],
@@ -468,18 +480,26 @@ object PartitioningUtils {
468480
}
469481

470482
/**
471-
* Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower"
472-
* types.
483+
* Given a collection of [[Literal]]s, resolves possible type conflicts by
484+
* [[findWiderTypeForPartitionColumn]].
473485
*/
474486
private def resolveTypeConflicts(literals: Seq[Literal], timeZone: TimeZone): Seq[Literal] = {
475-
val desiredType = {
476-
val topType = literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_))
477-
// Falls back to string if all values of this column are null or empty string
478-
if (topType == NullType) StringType else topType
479-
}
487+
val litTypes = literals.map(_.dataType)
488+
val desiredType = litTypes.reduce(findWiderTypeForPartitionColumn)
480489

481490
literals.map { case l @ Literal(_, dataType) =>
482491
Literal.create(Cast(l, desiredType, Some(timeZone.getID)).eval(), desiredType)
483492
}
484493
}
494+
495+
/**
496+
* Type widening rule for partition column types. It is similar to
497+
* [[TypeCoercion.findWiderTypeForTwo]] but the main difference is that here we disallow
498+
* precision loss when widening double/long and decimal, and fall back to string.
499+
*/
500+
private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = {
501+
case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType
502+
case (DoubleType, LongType) | (LongType, DoubleType) => StringType
503+
case (t1, t2) => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType)
504+
}
485505
}

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
249249
true,
250250
rootPaths,
251251
timeZoneId)
252+
assert(actualSpec.partitionColumns === spec.partitionColumns)
253+
assert(actualSpec.partitions.length === spec.partitions.length)
254+
actualSpec.partitions.zip(spec.partitions).foreach { case (actual, expected) =>
255+
assert(actual === expected)
256+
}
252257
assert(actualSpec === spec)
253258
}
254259

@@ -314,7 +319,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
314319
PartitionSpec(
315320
StructType(Seq(
316321
StructField("a", DoubleType),
317-
StructField("b", StringType))),
322+
StructField("b", NullType))),
318323
Seq(
319324
Partition(InternalRow(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"),
320325
Partition(InternalRow(10.5, null),
@@ -324,6 +329,32 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
324329
s"hdfs://host:9000/path1",
325330
s"hdfs://host:9000/path2"),
326331
PartitionSpec.emptySpec)
332+
333+
// The cases below check the resolution for type conflicts.
334+
val t1 = Timestamp.valueOf("2014-01-01 00:00:00.0").getTime * 1000
335+
val t2 = Timestamp.valueOf("2014-01-01 00:01:00.0").getTime * 1000
336+
// Values in column 'a' are inferred as null, date and timestamp each, and timestamp is set
337+
// as a common type.
338+
// Values in column 'b' are inferred as integer, decimal(22, 0) and null, and decimal(22, 0)
339+
// is set as a common type.
340+
check(Seq(
341+
s"hdfs://host:9000/path/a=$defaultPartitionName/b=0",
342+
s"hdfs://host:9000/path/a=2014-01-01/b=${Long.MaxValue}111",
343+
s"hdfs://host:9000/path/a=2014-01-01 00%3A01%3A00.0/b=$defaultPartitionName"),
344+
PartitionSpec(
345+
StructType(Seq(
346+
StructField("a", TimestampType),
347+
StructField("b", DecimalType(22, 0)))),
348+
Seq(
349+
Partition(
350+
InternalRow(null, Decimal(0)),
351+
s"hdfs://host:9000/path/a=$defaultPartitionName/b=0"),
352+
Partition(
353+
InternalRow(t1, Decimal(s"${Long.MaxValue}111")),
354+
s"hdfs://host:9000/path/a=2014-01-01/b=${Long.MaxValue}111"),
355+
Partition(
356+
InternalRow(t2, null),
357+
s"hdfs://host:9000/path/a=2014-01-01 00%3A01%3A00.0/b=$defaultPartitionName"))))
327358
}
328359

329360
test("parse partitions with type inference disabled") {
@@ -395,7 +426,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
395426
PartitionSpec(
396427
StructType(Seq(
397428
StructField("a", StringType),
398-
StructField("b", StringType))),
429+
StructField("b", NullType))),
399430
Seq(
400431
Partition(InternalRow(UTF8String.fromString("10"), null),
401432
s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"),
@@ -1067,4 +1098,26 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
10671098
checkAnswer(spark.read.load(path.getAbsolutePath), df)
10681099
}
10691100
}
1101+
1102+
test("Resolve type conflicts - decimals, dates and timestamps in partition column") {
1103+
withTempPath { path =>
1104+
val df = Seq((1, "2014-01-01"), (2, "2016-01-01"), (3, "2015-01-01 00:01:00")).toDF("i", "ts")
1105+
df.write.format("parquet").partitionBy("ts").save(path.getAbsolutePath)
1106+
checkAnswer(
1107+
spark.read.load(path.getAbsolutePath),
1108+
Row(1, Timestamp.valueOf("2014-01-01 00:00:00")) ::
1109+
Row(2, Timestamp.valueOf("2016-01-01 00:00:00")) ::
1110+
Row(3, Timestamp.valueOf("2015-01-01 00:01:00")) :: Nil)
1111+
}
1112+
1113+
withTempPath { path =>
1114+
val df = Seq((1, "1"), (2, "3"), (3, "2" * 30)).toDF("i", "decimal")
1115+
df.write.format("parquet").partitionBy("decimal").save(path.getAbsolutePath)
1116+
checkAnswer(
1117+
spark.read.load(path.getAbsolutePath),
1118+
Row(1, BigDecimal("1")) ::
1119+
Row(2, BigDecimal("3")) ::
1120+
Row(3, BigDecimal("2" * 30)) :: Nil)
1121+
}
1122+
}
10701123
}

0 commit comments

Comments
 (0)