
Commit fdcd140

szehon-ho authored and cloud-fan committed
[SPARK-53629][SQL] Implement type widening for MERGE INTO WITH SCHEMA EVOLUTION
### What changes were proposed in this pull request?

MERGE INTO WITH SCHEMA EVOLUTION already supports adding new columns, as well as some type widening (when structs are missing fields). It should also support type widening for primitive data types. Spark will call the V2DataSource TableCatalog to alter the schema, so the V2DataSource can decide whether the change is acceptable or not. This change also fixes InMemoryDataSource to support this case.

### Why are the changes needed?

Support more use cases for MERGE INTO WITH SCHEMA EVOLUTION.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a unit test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52377 from szehon-ho/merge_type_evolution.

Authored-by: Szehon Ho <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
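To illustrate the use case this enables, here is a minimal sketch (not part of this patch; the catalog, table, view, and column names are made up). The target's `amount` column is INT while the source supplies BIGINT values, so a merge with schema evolution should now ask the target's V2 catalog to widen the column instead of failing:

```scala
// Hypothetical names: cat.db.target, source_view, id, amount are illustrative only.
spark.sql("CREATE TABLE cat.db.target (id INT, amount INT)")
spark.sql("CREATE TEMPORARY VIEW source_view AS SELECT 1 AS id, CAST(100 AS BIGINT) AS amount")

// With schema evolution, the differing atomic types no longer abort the merge:
// Spark first requests that the catalog widen `amount` from INT to BIGINT.
spark.sql("""
  MERGE WITH SCHEMA EVOLUTION INTO cat.db.target t
  USING source_view s
  ON t.id = s.id
  WHEN MATCHED THEN UPDATE SET *
  WHEN NOT MATCHED THEN INSERT *
""")
```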
1 parent cf30da2 commit fdcd140

File tree

3 files changed: +440 -9 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala

Lines changed: 5 additions & 2 deletions

@@ -41,7 +41,7 @@ import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, AtomicType, BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType}
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.Utils

@@ -967,12 +967,15 @@ object MergeIntoTable {
         schemaChanges(currentElementType, updateElementType,
           originalTarget, originalSource, fieldPath ++ Seq("value"))

+      case (currentType: AtomicType, newType: AtomicType) if currentType != newType =>
+        Array(TableChange.updateColumnType(fieldPath, newType))
+
       case (currentType, newType) if currentType == newType =>
         // No change needed
         Array.empty[TableChange]

       case _ =>
-        // For now do not support type widening
+        // Do not support change between atomic and complex types for now
         throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError(
           originalTarget, originalSource, null)
     }
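To illustrate the new AtomicType branch above (a sketch, not part of the diff; the field name and types are assumed): when a target field `amount` is IntegerType and its source counterpart is LongType, `schemaChanges` now emits a single column-type change that Spark forwards to the V2 catalog's alterTable, so the data source can accept or reject the widening:

```scala
import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.types.LongType

// Hypothetical field path; equivalent to what the new case produces when the
// target's `amount` column is IntegerType and the source's is LongType.
val change: TableChange = TableChange.updateColumnType(Array("amount"), LongType)
```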

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala

Lines changed: 18 additions & 6 deletions

@@ -28,11 +28,12 @@ import scala.collection.mutable.ListBuffer
 import scala.jdk.CollectionConverters._

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, MetadataStructFieldWithLogicalName}
+import org.apache.spark.sql.catalyst.expressions.{Cast, EvalMode, GenericInternalRow, JoinedRow, Literal, MetadataStructFieldWithLogicalName}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, GenericArrayData, MapData, ResolveDefaultColumns}
 import org.apache.spark.sql.connector.catalog.constraints.Constraint
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
 import org.apache.spark.sql.connector.expressions._
+import org.apache.spark.sql.connector.expressions.{Literal => V2Literal}
 import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram, HistogramBin}

@@ -146,7 +147,7 @@ abstract class InMemoryBaseTable(
       case _: BucketTransform =>
       case _: SortedBucketTransform =>
       case _: ClusterByTransform =>
-      case NamedTransform("truncate", Seq(_: NamedReference, _: Literal[_])) =>
+      case NamedTransform("truncate", Seq(_: NamedReference, _: V2Literal[_])) =>
       case t if !allowUnsupportedTransforms =>
         throw new IllegalArgumentException(s"Transform $t is not a supported transform")
     }

@@ -244,7 +245,7 @@
        var dataTypeHashCode = 0
        valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
        ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
-      case NamedTransform("truncate", Seq(ref: NamedReference, length: Literal[_])) =>
+      case NamedTransform("truncate", Seq(ref: NamedReference, length: V2Literal[_])) =>
        extractor(ref.fieldNames, cleanedSchema, row) match {
          case (str: UTF8String, StringType) =>
            str.substring(0, length.value.asInstanceOf[Int])

@@ -910,7 +911,7 @@ private class BufferedRowsReader(
      arrayData: ArrayData,
      readType: DataType,
      writeType: DataType): ArrayData = {
-    val elements = arrayData.toArray[Any](readType)
+    val elements = arrayData.toArray[Any](writeType)
     val convertedElements = extractCollection(elements, readType, writeType)
     new GenericArrayData(convertedElements)
   }

@@ -921,8 +922,8 @@
      readValueType: DataType,
      writeKeyType: DataType,
      writeValueType: DataType): MapData = {
-    val keys = mapData.keyArray().toArray[Any](readKeyType)
-    val values = mapData.valueArray().toArray[Any](readValueType)
+    val keys = mapData.keyArray().toArray[Any](writeKeyType)
+    val values = mapData.valueArray().toArray[Any](writeValueType)

     val convertedKeys = extractCollection(keys, readKeyType, writeKeyType)
     val convertedValues = extractCollection(values, readValueType, writeValueType)

@@ -962,9 +963,20 @@
            wKeyType, wValueType)
        }
      }
+      case (readType: AtomicType, writeType: AtomicType) if readType != writeType =>
+        elements.map { elem =>
+          if (elem == null) {
+            null
+          } else {
+            castElement(elem, readType, writeType)
+          }
+        }
      case (_, _) => elements
    }
  }
+
+  private def castElement(elem: Any, toType: DataType, fromType: DataType): Any =
+    Cast(Literal(elem, fromType), toType, None, EvalMode.TRY).eval(null)
 }

 private class BufferedRowsWriterFactory(schema: StructType)
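As a rough illustration of what the new castElement helper does (a sketch with made-up values, not code from the commit): a value buffered under the old, narrower type is converted to the widened read type via a TRY-mode cast, which yields null rather than throwing if the conversion fails.

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, EvalMode, Literal}
import org.apache.spark.sql.types.{IntegerType, LongType}

// Mirrors castElement: cast a value stored as IntegerType to the widened LongType.
// EvalMode.TRY turns an unconvertible value into null instead of an error.
val widened: Any = Cast(Literal(42, IntegerType), LongType, None, EvalMode.TRY).eval(null)
// widened == 42L
```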
