diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java new file mode 100644 index 000000000000..56115da0d804 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java @@ -0,0 +1,360 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark; + +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.Decimal; +import org.apache.paimon.data.InternalArray; +import org.apache.paimon.data.InternalMap; +import org.apache.paimon.data.InternalRow; +import org.apache.paimon.data.Timestamp; +import org.apache.paimon.data.variant.Variant; +import org.apache.paimon.spark.util.SparkRowUtils$; +import org.apache.paimon.spark.util.shim.TypeUtils$; +import org.apache.paimon.types.RowKind; + +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.paimon.shims.SparkShimLoader; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampNTZType; +import org.apache.spark.sql.types.TimestampType; + +import java.math.BigDecimal; + +/** Wrapper to fetch value from the spark internal row. 
*/ +public class SparkInternalRowWrapper implements InternalRow { + + private org.apache.spark.sql.catalyst.InternalRow internalRow; + private final int length; + private final int rowKindIdx; + private final StructType structType; + + public SparkInternalRowWrapper( + org.apache.spark.sql.catalyst.InternalRow internalRow, + int rowKindIdx, + StructType structType, + int length) { + this.internalRow = internalRow; + this.rowKindIdx = rowKindIdx; + this.length = length; + this.structType = structType; + } + + public SparkInternalRowWrapper(int rowKindIdx, StructType structType, int length) { + this.rowKindIdx = rowKindIdx; + this.length = length; + this.structType = structType; + } + + public SparkInternalRowWrapper replace(org.apache.spark.sql.catalyst.InternalRow internalRow) { + this.internalRow = internalRow; + return this; + } + + @Override + public int getFieldCount() { + return length; + } + + @Override + public RowKind getRowKind() { + return SparkRowUtils$.MODULE$.getRowKind(internalRow, rowKindIdx); + } + + @Override + public void setRowKind(RowKind kind) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isNullAt(int pos) { + return internalRow.isNullAt(pos); + } + + @Override + public boolean getBoolean(int pos) { + return internalRow.getBoolean(pos); + } + + @Override + public byte getByte(int pos) { + return internalRow.getByte(pos); + } + + @Override + public short getShort(int pos) { + return internalRow.getShort(pos); + } + + @Override + public int getInt(int pos) { + return internalRow.getInt(pos); + } + + @Override + public long getLong(int pos) { + return internalRow.getLong(pos); + } + + @Override + public float getFloat(int pos) { + return internalRow.getFloat(pos); + } + + @Override + public double getDouble(int pos) { + return internalRow.getDouble(pos); + } + + @Override + public BinaryString getString(int pos) { + return BinaryString.fromBytes(internalRow.getUTF8String(pos).getBytes()); + } + + @Override + public Decimal getDecimal(int pos, int precision, int scale) { + org.apache.spark.sql.types.Decimal decimal = internalRow.getDecimal(pos, precision, scale); + BigDecimal bigDecimal = decimal.toJavaBigDecimal(); + return Decimal.fromBigDecimal(bigDecimal, precision, scale); + } + + @Override + public Timestamp getTimestamp(int pos, int precision) { + return convertToTimestamp(structType.fields()[pos].dataType(), internalRow.getLong(pos)); + } + + @Override + public byte[] getBinary(int pos) { + return internalRow.getBinary(pos); + } + + @Override + public Variant getVariant(int pos) { + return SparkShimLoader.getSparkShim().toPaimonVariant(internalRow, pos); + } + + @Override + public InternalArray getArray(int pos) { + return new SparkInternalArray( + internalRow.getArray(pos), + ((ArrayType) (structType.fields()[pos].dataType())).elementType()); + } + + @Override + public InternalMap getMap(int pos) { + MapType mapType = (MapType) structType.fields()[pos].dataType(); + return new SparkInternalMap( + internalRow.getMap(pos), mapType.keyType(), mapType.valueType()); + } + + @Override + public InternalRow getRow(int pos, int numFields) { + return new SparkInternalRowWrapper( + internalRow.getStruct(pos, numFields), + -1, + (StructType) structType.fields()[pos].dataType(), + numFields); + } + + private static Timestamp convertToTimestamp(DataType dataType, long micros) { + if (dataType instanceof TimestampType) { + if (TypeUtils$.MODULE$.treatPaimonTimestampTypeAsSparkTimestampType()) { + return 
Timestamp.fromSQLTimestamp(DateTimeUtils.toJavaTimestamp(micros)); + } else { + return Timestamp.fromMicros(micros); + } + } else if (dataType instanceof TimestampNTZType) { + return Timestamp.fromMicros(micros); + } else { + throw new UnsupportedOperationException("Unsupported data type:" + dataType); + } + } + + /** adapt to spark internal array. */ + public static class SparkInternalArray implements InternalArray { + + private final ArrayData arrayData; + private final DataType elementType; + + public SparkInternalArray(ArrayData arrayData, DataType elementType) { + this.arrayData = arrayData; + this.elementType = elementType; + } + + @Override + public int size() { + return arrayData.numElements(); + } + + @Override + public boolean[] toBooleanArray() { + return arrayData.toBooleanArray(); + } + + @Override + public byte[] toByteArray() { + return arrayData.toByteArray(); + } + + @Override + public short[] toShortArray() { + return arrayData.toShortArray(); + } + + @Override + public int[] toIntArray() { + return arrayData.toIntArray(); + } + + @Override + public long[] toLongArray() { + return arrayData.toLongArray(); + } + + @Override + public float[] toFloatArray() { + return arrayData.toFloatArray(); + } + + @Override + public double[] toDoubleArray() { + return arrayData.toDoubleArray(); + } + + @Override + public boolean isNullAt(int pos) { + return arrayData.isNullAt(pos); + } + + @Override + public boolean getBoolean(int pos) { + return arrayData.getBoolean(pos); + } + + @Override + public byte getByte(int pos) { + return arrayData.getByte(pos); + } + + @Override + public short getShort(int pos) { + return arrayData.getShort(pos); + } + + @Override + public int getInt(int pos) { + return arrayData.getInt(pos); + } + + @Override + public long getLong(int pos) { + return arrayData.getLong(pos); + } + + @Override + public float getFloat(int pos) { + return arrayData.getFloat(pos); + } + + @Override + public double getDouble(int pos) { + return arrayData.getDouble(pos); + } + + @Override + public BinaryString getString(int pos) { + return BinaryString.fromBytes(arrayData.getUTF8String(pos).getBytes()); + } + + @Override + public Decimal getDecimal(int pos, int precision, int scale) { + org.apache.spark.sql.types.Decimal decimal = + arrayData.getDecimal(pos, precision, scale); + return Decimal.fromBigDecimal(decimal.toJavaBigDecimal(), precision, scale); + } + + @Override + public Timestamp getTimestamp(int pos, int precision) { + return convertToTimestamp(elementType, arrayData.getLong(pos)); + } + + @Override + public byte[] getBinary(int pos) { + return arrayData.getBinary(pos); + } + + @Override + public Variant getVariant(int pos) { + return SparkShimLoader.getSparkShim().toPaimonVariant(arrayData, pos); + } + + @Override + public InternalArray getArray(int pos) { + return new SparkInternalArray( + arrayData.getArray(pos), ((ArrayType) elementType).elementType()); + } + + @Override + public InternalMap getMap(int pos) { + MapType mapType = (MapType) elementType; + return new SparkInternalMap( + arrayData.getMap(pos), mapType.keyType(), mapType.valueType()); + } + + @Override + public InternalRow getRow(int pos, int numFields) { + return new SparkInternalRowWrapper( + arrayData.getStruct(pos, numFields), -1, (StructType) elementType, numFields); + } + } + + /** adapt to spark internal map. 
*/ + public static class SparkInternalMap implements InternalMap { + + private final MapData mapData; + private final DataType keyType; + private final DataType valueType; + + public SparkInternalMap(MapData mapData, DataType keyType, DataType valueType) { + this.mapData = mapData; + this.keyType = keyType; + this.valueType = valueType; + } + + @Override + public int size() { + return mapData.numElements(); + } + + @Override + public InternalArray keyArray() { + return new SparkInternalArray(mapData.keyArray(), keyType); + } + + @Override + public InternalArray valueArray() { + return new SparkInternalArray(mapData.valueArray(), valueType); + } + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/Compatibility.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/Compatibility.scala index 751b88f585b9..5f78cda21bdc 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/Compatibility.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/Compatibility.scala @@ -18,7 +18,7 @@ package org.apache.paimon.spark.catalyst -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EvalMode, Expression} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand} import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.connector.read.Scan diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/BucketExpression.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/BucketExpression.scala new file mode 100644 index 000000000000..cc411eb9fce3 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/BucketExpression.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.commands + +import org.apache.paimon.data.serializer.InternalRowSerializer +import org.apache.paimon.spark.SparkInternalRowWrapper +import org.apache.paimon.spark.SparkTypeUtils.toPaimonType +import org.apache.paimon.table.sink.KeyAndBucketExtractor.{bucket, bucketKeyHashCode} +import org.apache.paimon.types.{RowKind, RowType} + +import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow => SparkInternalRow} +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType} + +/** @param _children arg0: bucket number, arg1..argn bucket key */ +case class FixedBucketExpression(_children: Seq[Expression]) + extends Expression + with CodegenFallback { + + private lazy val (bucketKeyRowType: RowType, bucketKeyStructType: StructType) = { + val (originalTypes, paimonTypes) = _children.tail.map { + expr => + (StructField(expr.prettyName, expr.dataType, nullable = true), toPaimonType(expr.dataType)) + }.unzip + + ( + RowType.of(paimonTypes: _*), + StructType(originalTypes) + ) + } + + private lazy val numberBuckets = _children.head.asInstanceOf[Literal].value.asInstanceOf[Int] + private lazy val serializer = new InternalRowSerializer(bucketKeyRowType) + private lazy val wrapper = + new SparkInternalRowWrapper(-1, bucketKeyStructType, bucketKeyStructType.fields.length) + + override def nullable: Boolean = false + + override def eval(input: SparkInternalRow): Int = { + val bucketKeyValues = _children.tail.map(_.eval(input)) + bucket( + bucketKeyHashCode( + serializer.toBinaryRow(wrapper.replace(SparkInternalRow.fromSeq(bucketKeyValues)))), + numberBuckets) + } + + override def dataType: DataType = DataTypes.IntegerType + + override def children: Seq[Expression] = _children + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + copy(_children = newChildren) + } + + override def canEqual(that: Any): Boolean = false +} + +object BucketExpression { + + val FIXED_BUCKET = "fixed_bucket" + val supportedFnNames: Seq[String] = Seq(FIXED_BUCKET) + + private type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) + + def getFunctionInjection(fnName: String): FunctionDescription = { + val (info, builder) = fnName match { + case FIXED_BUCKET => + FunctionRegistryBase.build[FixedBucketExpression](fnName, since = None) + case _ => + throw new Exception(s"Function $fnName isn't a supported scalar function.") + } + val ident = FunctionIdentifier(fnName) + (ident, info, builder) + } + +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala index 559fe900b431..db9800b46938 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala @@ -25,10 +25,10 @@ import org.apache.paimon.crosspartition.{IndexBootstrap, KeyPartOrRow} import org.apache.paimon.data.serializer.InternalSerializers import org.apache.paimon.deletionvectors.DeletionVector 
import org.apache.paimon.deletionvectors.append.AppendDeletionFileMaintainer -import org.apache.paimon.index.{BucketAssigner, PartitionIndex, SimpleHashBucketAssigner} +import org.apache.paimon.index.{BucketAssigner, SimpleHashBucketAssigner} import org.apache.paimon.io.{CompactIncrement, DataIncrement, IndexIncrement} import org.apache.paimon.manifest.{FileKind, IndexManifestEntry} -import org.apache.paimon.spark.{SparkRow, SparkTableWrite, SparkTypeUtils} +import org.apache.paimon.spark.{SparkInternalRowWrapper, SparkRow, SparkTableWrite, SparkTypeUtils} import org.apache.paimon.spark.schema.SparkSystemColumns.{BUCKET_COL, ROW_KIND_COL} import org.apache.paimon.spark.util.SparkRowUtils import org.apache.paimon.table.BucketMode._ @@ -39,8 +39,10 @@ import org.apache.paimon.utils.{InternalRowPartitionComputer, PartitionPathUtils import org.apache.spark.{Partitioner, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StructType import org.slf4j.LoggerFactory import java.io.IOException @@ -58,6 +60,11 @@ case class PaimonSparkWriter(table: FileStoreTable) { private lazy val log = LoggerFactory.getLogger(classOf[PaimonSparkWriter]) + private val extensionKey = "spark.sql.extensions" + + private val paimonSparkExtension = + "org.apache.paimon.spark.extensions.PaimonSparkSessionExtensions" + @transient private lazy val serializer = new CommitMessageSerializer val writeBuilder: BatchWriteBuilder = table.newBatchWriteBuilder() @@ -70,7 +77,16 @@ case class PaimonSparkWriter(table: FileStoreTable) { val sparkSession = data.sparkSession import sparkSession.implicits._ + def paimonExtensionEnabled: Boolean = { + val extensions = sparkSession.sessionState.conf.getConfString(extensionKey) + if (extensions != null && extensions.contains(paimonSparkExtension)) { + true + } else { + false + } + } val withInitBucketCol = bucketMode match { + case BUCKET_UNAWARE => data case CROSS_PARTITION if !data.schema.fieldNames.contains(ROW_KIND_COL) => data .withColumn(ROW_KIND_COL, lit(RowKind.INSERT.toByteValue)) @@ -229,10 +245,26 @@ case class PaimonSparkWriter(table: FileStoreTable) { writeWithoutBucket(data) case HASH_FIXED => - // Topology: input -> bucket-assigner -> shuffle by partition & bucket - writeWithBucketProcessor( - withInitBucketCol, - CommonBucketProcessor(table, bucketColIdx, encoderGroupWithBucketCol)) + if (!paimonExtensionEnabled) { + // Topology: input -> bucket-assigner -> shuffle by partition & bucket + writeWithBucketProcessor( + withInitBucketCol, + CommonBucketProcessor(table, bucketColIdx, encoderGroupWithBucketCol)) + } else { + // Topology: input -> shuffle by partition & bucket + val bucketNumber = table.coreOptions().bucket() + val bucketKeyCol = tableSchema + .bucketKeys() + .asScala + .map(tableSchema.fieldNames().indexOf(_)) + .map(x => col(data.schema.fieldNames(x))) + .toSeq + val args = Seq(lit(bucketNumber)) ++ bucketKeyCol + val repartitioned = + repartitionByPartitionsAndBucket( + data.withColumn(BUCKET_COL, call_udf(BucketExpression.FIXED_BUCKET, args: _*))) + writeWithBucket(repartitioned) + } case _ => throw new UnsupportedOperationException(s"Spark doesn't support $bucketMode mode.") diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala index bfd337580dbf..4ecff93ea679 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala @@ -21,6 +21,7 @@ package org.apache.paimon.spark.extensions import org.apache.paimon.spark.catalyst.analysis.{PaimonAnalysis, PaimonDeleteTable, PaimonIncompatiblePHRRules, PaimonIncompatibleResolutionRules, PaimonMergeInto, PaimonPostHocResolutionRules, PaimonProcedureResolver, PaimonUpdateTable, PaimonViewResolver, ReplacePaimonFunctions} import org.apache.paimon.spark.catalyst.optimizer.{EvalSubqueriesForDeleteTable, MergePaimonScalarSubqueries} import org.apache.paimon.spark.catalyst.plans.logical.PaimonTableValuedFunctions +import org.apache.paimon.spark.commands.BucketExpression import org.apache.paimon.spark.execution.PaimonStrategy import org.apache.paimon.spark.execution.adaptive.DisableUnnecessaryPaimonBucketedScan @@ -59,6 +60,11 @@ class PaimonSparkSessionExtensions extends (SparkSessionExtensions => Unit) { PaimonTableValuedFunctions.getTableValueFunctionInjection(fnName)) } + // scalar function extensions + BucketExpression.supportedFnNames.foreach { + fnName => extensions.injectFunction(BucketExpression.getFunctionInjection(fnName)) + } + // optimization rules extensions.injectOptimizerRule(_ => EvalSubqueriesForDeleteTable) extensions.injectOptimizerRule(_ => MergePaimonScalarSubqueries) diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/SparkRowUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/SparkRowUtils.scala index caedadf1f2cf..5b07f9195221 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/SparkRowUtils.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/util/SparkRowUtils.scala @@ -21,6 +21,7 @@ package org.apache.paimon.spark.util import org.apache.paimon.types.RowKind import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType object SparkRowUtils { @@ -33,6 +34,14 @@ object SparkRowUtils { } } + def getRowKind(row: InternalRow, rowkindColIdx: Int): RowKind = { + if (rowkindColIdx != -1) { + RowKind.fromByteValue(row.getByte(rowkindColIdx)) + } else { + RowKind.INSERT + } + } + def getFieldIndex(schema: StructType, colName: String): Int = { try { schema.fieldIndex(colName) diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/paimon/shims/SparkShim.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/paimon/shims/SparkShim.scala index 3d29d7c3c577..c460687f73cd 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/paimon/shims/SparkShim.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/paimon/shims/SparkShim.scala @@ -23,10 +23,12 @@ import org.apache.paimon.spark.data.{SparkArrayData, SparkInternalRow} import org.apache.paimon.types.{DataType, RowType} import org.apache.spark.sql.{Column, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, 
Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.types.StructType @@ -66,6 +68,10 @@ trait SparkShim { // for variant def toPaimonVariant(o: Object): Variant + def toPaimonVariant(row: InternalRow, pos: Int): Variant + + def toPaimonVariant(array: ArrayData, pos: Int): Variant + def isSparkVariantType(dataType: org.apache.spark.sql.types.DataType): Boolean def SparkVariantType(): org.apache.spark.sql.types.DataType diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkWriteITCase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkWriteITCase.scala new file mode 100644 index 000000000000..96abb1b84019 --- /dev/null +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkWriteITCase.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.sql + +import org.apache.paimon.spark.PaimonSparkTestBase + +import org.apache.spark.sql.Row +import org.junit.jupiter.api.Assertions + +import java.sql.Timestamp +import java.time.LocalDateTime + +class SparkWriteITCase extends PaimonSparkTestBase { + + import testImplicits._ + + test("Paimon Write: AllTypes") { + withTable("AllTypesTable") { + val createTableSQL = + """ + |CREATE TABLE AllTypesTable ( + | byte_col BYTE NOT NULL, + | short_col SHORT, + | int_col INT NOT NULL, + | long_col LONG, + | float_col FLOAT, + | double_col DOUBLE NOT NULL, + | decimal_col DECIMAL(10,2), + | string_col STRING, + | binary_col BINARY, + | boolean_col BOOLEAN NOT NULL, + | date_col DATE, + | timestamp_col TIMESTAMP, + | timestamp_ntz_col TIMESTAMP_NTZ, + | array_col ARRAY, + | map_col MAP, + | struct_col STRUCT + |) TBLPROPERTIES ( + | 'bucket' = '2', + | 'bucket-key' = 'int_col' + |) + |""".stripMargin + sql(createTableSQL) + + sql(""" + |INSERT INTO AllTypesTable VALUES ( + | 1Y, -- byte_col (NOT NULL) + | 100S, -- short_col + | 42, -- int_col (NOT NULL) + | 9999999999L, -- long_col + | 3.14F, -- float_col + | CAST(2.71828 as double), -- double_col (NOT NULL) + | CAST('123.45' AS DECIMAL(10,2)), -- decimal_col + | 'test_string', -- string_col + | unhex('0001'), -- binary_col + | true, -- boolean_col (NOT NULL) + | DATE '2023-10-01', -- date_col + | TIMESTAMP '2023-10-01 12:34:56', -- timestamp_col + | TIMESTAMP_NTZ '2023-10-01 12:34:56', -- timestamp_ntz_col + | ARRAY(1, 2, 3), -- array_col + | MAP('key1', 1, 'key2', 2), -- map_col + | NAMED_STRUCT('f1', 10, 'f2', 'struct_field') -- struct_col + |) + |""".stripMargin) + + checkAnswer( + sql("SELECT * FROM AllTypesTable"), + Row( + 1.toByte, // byte_col + 100.toShort, // short_col + 42, // int_col + 9999999999L, // long_col + 3.14f, // float_col + 2.71828, // double_col + new java.math.BigDecimal("123.45"), // decimal_col + "test_string", // string_col + Array(0x00, 0x01), // binary_col + true, // boolean_col + java.sql.Date.valueOf("2023-10-01"), // date_col + java.sql.Timestamp.valueOf("2023-10-01 12:34:56"), // timestamp_col + LocalDateTime.parse("2023-10-01T12:34:56"), // timestamp_ntz_col + Array(1, 2, 3), // array_col + Map("key1" -> 1, "key2" -> 2), // map_col + Row(10, "struct_field") // struct_col + ) :: Nil + ) + } + } + + test("Paimon Write : Nested type") { + withTable("NestedTypesTable") { + val createTableSQL = + """ + |CREATE TABLE NestedTypesTable ( + | id INT NOT NULL, + | map_col MAP>, + | struct_col STRUCT< + | name: STRING, + | details: MAP, + | scores: ARRAY + | >, + | nested_array_col ARRAY, + | sub_array: ARRAY + | >> NOT NULL + |) + |""".stripMargin + spark.sql(createTableSQL) + + spark.sql(""" + |INSERT INTO NestedTypesTable VALUES + |( + | 1, + | MAP('key1', ARRAY(1, 2, 3), 'key2', ARRAY(4, 5)), -- map_col + | STRUCT( -- struct_col + | 'user1', + | MAP('age', 25, 'score', 99), + | ARRAY(CAST(90.5 as double), CAST(88.0 as double)) + | ), + | ARRAY( -- nested_array_col + | STRUCT(MAP('a', 1), ARRAY(10, 20)), + | STRUCT(MAP('b', 2, 'c', 3), ARRAY(30)) + | ) + |) + |""".stripMargin) + + checkAnswer( + spark.sql("SELECT * FROM NestedTypesTable WHERE id = 1"), + Row( + 1, // id + Map( // map_col + "key1" -> Seq(1, 2, 3), + "key2" -> Seq(4, 5)), + Row( // struct_col + "user1", + Map("age" -> 25, "score" -> 99), + Seq(90.5, 88.0)), + Seq( // nested_array_col + Row(Map("a" -> 1), Seq(10, 20)), + Row(Map("b" -> 2, "c" -> 3), Seq(30))) + ) :: Nil + ) + } + } + + test("Paimon write: nested type 
with timestamp/timestamp_ntz") { + withTable("NestedTimestampTable") { + val createTableSQL = + """ + |CREATE TABLE NestedTimestampTable ( + | id INT NOT NULL, + | struct_col STRUCT< + | ts_ltz: TIMESTAMP, + | ts_ntz: TIMESTAMP_NTZ, + | map_field: MAP + | >, + | array_col ARRAY> NOT NULL + |) + |""".stripMargin + spark.sql(createTableSQL) + + spark.sql(""" + |INSERT INTO NestedTimestampTable VALUES ( + | 1, + | STRUCT( + | TIMESTAMP '2023-10-01 12:00:00', + | TIMESTAMP_NTZ '2023-10-01 12:00:00', + | MAP('ntz1', TIMESTAMP_NTZ '2023-10-01 08:00:00') + | ), + | ARRAY( + | STRUCT( + | TIMESTAMP '2023-10-01 13:00:00', + | TIMESTAMP_NTZ '2023-10-01 13:00:00' + | ) + | ) + |) + |""".stripMargin) + + val expectedTsLtz = Timestamp.valueOf("2023-10-01 12:00:00") + val expectedTsNtz = LocalDateTime.parse("2023-10-01T12:00:00") + checkAnswer( + spark.sql("SELECT struct_col.ts_ltz, struct_col.ts_ntz FROM NestedTimestampTable"), + Row(expectedTsLtz, expectedTsNtz) :: Nil + ) + + val mapValue = spark + .sql("SELECT struct_col.map_field['ntz1'] FROM NestedTimestampTable") + .collect()(0) + .getAs[LocalDateTime](0) + Assertions.assertEquals( + LocalDateTime.parse("2023-10-01T08:00:00"), + mapValue + ) + + // timestamp in array + checkAnswer( + spark.sql("SELECT array_col[0].ts_ltz, array_col[0].ts_ntz FROM NestedTimestampTable"), + Row( + Timestamp.valueOf("2023-10-01 13:00:00"), + LocalDateTime.parse("2023-10-01T13:00:00") + ) :: Nil + ) + + } + } +} diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkWriteWithNoExtensionITCase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkWriteWithNoExtensionITCase.scala new file mode 100644 index 000000000000..9daa9f7d3b39 --- /dev/null +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkWriteWithNoExtensionITCase.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +import org.apache.spark.SparkConf + +/** Test for spark writer with extension disabled. */ +class SparkWriteWithNoExtensionITCase extends SparkWriteITCase { + + /** Disable the spark extension. 
*/ + override protected def sparkConf: SparkConf = { + super.sparkConf.remove("spark.sql.extensions") + } +} diff --git a/paimon-spark/paimon-spark3-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark3Shim.scala b/paimon-spark/paimon-spark3-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark3Shim.scala index 9b96a64fb1c4..62c52ab57ba7 100644 --- a/paimon-spark/paimon-spark3-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark3Shim.scala +++ b/paimon-spark/paimon-spark3-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark3Shim.scala @@ -25,10 +25,12 @@ import org.apache.paimon.spark.data.{Spark3ArrayData, Spark3InternalRow, SparkAr import org.apache.paimon.types.{DataType, RowType} import org.apache.spark.sql.{Column, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.types.StructType @@ -79,4 +81,10 @@ class Spark3Shim extends SparkShim { override def SparkVariantType(): org.apache.spark.sql.types.DataType = throw new UnsupportedOperationException() + + override def toPaimonVariant(row: InternalRow, pos: Int): Variant = + throw new UnsupportedOperationException() + + override def toPaimonVariant(array: ArrayData, pos: Int): Variant = + throw new UnsupportedOperationException() } diff --git a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark4Shim.scala b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark4Shim.scala index 33eefc7d568c..8a0dad7733c9 100644 --- a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark4Shim.scala +++ b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark4Shim.scala @@ -25,10 +25,12 @@ import org.apache.paimon.spark.data.{Spark4ArrayData, Spark4InternalRow, SparkAr import org.apache.paimon.types.{DataType, RowType} import org.apache.spark.sql.{Column, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.internal.ExpressionUtils @@ -81,6 +83,16 @@ class Spark4Shim extends SparkShim { new GenericVariant(v.getValue, v.getMetadata) } + override def toPaimonVariant(row: InternalRow, pos: Int): Variant = { + val v = row.getVariant(pos) + new GenericVariant(v.getValue, v.getMetadata) + } + + override def toPaimonVariant(array: ArrayData, pos: Int): Variant = { + val v = array.getVariant(pos) + new GenericVariant(v.getValue, v.getMetadata) + } + override def isSparkVariantType(dataType: org.apache.spark.sql.types.DataType): Boolean = dataType.isInstanceOf[VariantType]
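For reference, a minimal usage sketch (not part of the diff above): once PaimonSparkSessionExtensions is registered via spark.sql.extensions, the injected fixed_bucket(numBuckets, bucketKey...) scalar function resolves like any other registered function, which is exactly what PaimonSparkWriter relies on through call_udf. The session setup, the bucket count of 16, and the range(10)/id column below are illustrative assumptions, not part of this change.

// Hedged sketch: invoke the injected fixed_bucket function in SQL and DataFrame form.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{call_udf, col, lit}

object FixedBucketSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[2]")
      .config(
        "spark.sql.extensions",
        "org.apache.paimon.spark.extensions.PaimonSparkSessionExtensions")
      .getOrCreate()

    // SQL form: compute the Paimon fixed-bucket id of column `id` for 16 buckets.
    spark.sql("SELECT id, fixed_bucket(16, id) AS bucket FROM range(10)").show()

    // DataFrame form, mirroring the writer's call_udf(BucketExpression.FIXED_BUCKET, ...) call.
    spark
      .range(10)
      .withColumn("bucket", call_udf("fixed_bucket", lit(16), col("id")))
      .show()

    spark.stop()
  }
}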