diff --git a/.gitignore b/.gitignore
index 8e7d93ebaccc..f164da86f55d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,6 +18,7 @@ target
*.iws
.java-version
dependency-reduced-pom.xml
+benchmark/*
### VS Code ###
.vscode/
diff --git a/paimon-benchmark/paimon-spark-benchmark/pom.xml b/paimon-benchmark/paimon-spark-benchmark/pom.xml
new file mode 100644
index 000000000000..2d27438512aa
--- /dev/null
+++ b/paimon-benchmark/paimon-spark-benchmark/pom.xml
@@ -0,0 +1,162 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <artifactId>paimon-parent</artifactId>
+        <groupId>org.apache.paimon</groupId>
+        <version>1.1-SNAPSHOT</version>
+    </parent>
+
+    <artifactId>paimon-spark-benchmark</artifactId>
+    <packaging>jar</packaging>
+    <name>Apache Paimon Benchmarks</name>
+
+    <properties>
+        <spark.version>3.5.4</spark.version>
+        <uberjar.name>paimon-spark-benchmark</uberjar.name>
+        <jmh.version>1.37</jmh.version>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>org.apache.paimon</groupId>
+            <artifactId>paimon-spark-3.5</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-core_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.openjdk.jmh</groupId>
+            <artifactId>jmh-core</artifactId>
+            <version>${jmh.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.openjdk.jmh</groupId>
+            <artifactId>jmh-generator-annprocess</artifactId>
+            <version>${jmh.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.logging.log4j</groupId>
+            <artifactId>log4j-api</artifactId>
+            <version>${log4j.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.logging.log4j</groupId>
+            <artifactId>log4j-core</artifactId>
+            <version>${log4j.version}</version>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-compiler-plugin</artifactId>
+            </plugin>
+
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-shade-plugin</artifactId>
+                <version>2.2</version>
+                <executions>
+                    <execution>
+                        <phase>package</phase>
+                        <goals>
+                            <goal>shade</goal>
+                        </goals>
+                        <configuration>
+                            <finalName>${uberjar.name}</finalName>
+                            <transformers>
+                                <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
+                                    <mainClass>org.openjdk.jmh.Main</mainClass>
+                                </transformer>
+                            </transformers>
+                            <filters>
+                                <filter>
+                                    <artifact>*:*</artifact>
+                                    <excludes>
+                                        <exclude>META-INF/*.SF</exclude>
+                                        <exclude>META-INF/*.DSA</exclude>
+                                        <exclude>META-INF/*.RSA</exclude>
+                                    </excludes>
+                                </filter>
+                            </filters>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-dependency-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <goals>
+                            <goal>analyze-only</goal>
+                        </goals>
+                        <configuration>
+                            <failOnWarning>true</failOnWarning>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>
+</project>
\ No newline at end of file
diff --git a/paimon-benchmark/paimon-spark-benchmark/src/main/java/org/apache/paimon/spark/source/PaimonSourceWriteBenchmark.java b/paimon-benchmark/paimon-spark-benchmark/src/main/java/org/apache/paimon/spark/source/PaimonSourceWriteBenchmark.java
new file mode 100644
index 000000000000..079192d84b63
--- /dev/null
+++ b/paimon-benchmark/paimon-spark-benchmark/src/main/java/org/apache/paimon/spark/source/PaimonSourceWriteBenchmark.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.source;
+
+import org.apache.paimon.spark.SparkCatalog;
+import org.apache.paimon.spark.SparkSQLProperties;
+
+import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableMap;
+import org.apache.paimon.shade.guava30.com.google.common.collect.Maps;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.internal.SQLConf;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Threads;
+import org.openjdk.jmh.annotations.Warmup;
+
+import java.io.File;
+import java.util.Map;
+import java.util.UUID;
+
+import static org.apache.spark.sql.functions.expr;
+
+/**
+ * Paimon Spark source write benchmark.
+ *
+ * <p>Usage:
+ *
+ * <pre>
+ *   mvn clean install -pl ':paimon-spark-benchmark' -DskipTests=true
+ *
+ *   java -jar ./paimon-benchmark/paimon-spark-benchmark/target/paimon-spark-benchmark.jar
+ *       org.apache.paimon.spark.source.PaimonSourceWriteBenchmark -o
+ *       benchmark/paimon-source-write-result.txt
+ * </pre>
+ */
+@Fork(1)
+@State(Scope.Benchmark)
+@Warmup(iterations = 3)
+@Measurement(iterations = 5)
+@BenchmarkMode(Mode.SingleShotTime)
+public class PaimonSourceWriteBenchmark {
+ public static final String PAIMON_TABLE_NAME = "paimon_table";
+ private static final int NUM_ROWS = 3_000_000;
+
+ private final String warehousePath = warehousePath();
+
+ private SparkSession spark;
+
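+ // Creates a fixed-bucket primary-key table ('bucket'='2'), so writes also exercise
+ // the bucket-based distribution requirement of the V2 write path.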
+ protected void setup() {
+ spark =
+ SparkSession.builder()
+ .config("spark.ui.enabled", false)
+ .master("local")
+ .config("spark.sql.catalog.paimon", SparkCatalog.class.getName())
+ .config("spark.sql.catalog.paimon.warehouse", warehousePath)
+ .getOrCreate();
+
+ spark.sql("CREATE DATABASE paimon.db");
+ spark.sql("USE paimon.db");
+
+ spark.sql(
+ String.format(
+ "CREATE TABLE %s ("
+ + "intCol INT NOT NULL, "
+ + "longCol BIGINT NOT NULL, "
+ + "floatCol FLOAT NOT NULL, "
+ + "doubleCol DOUBLE NOT NULL, "
+ + "decimalCol DECIMAL(20, 5) NOT NULL, "
+ + "stringCol1 STRING NOT NULL, "
+ + "stringCol2 STRING NOT NULL, "
+ + "stringCol3 STRING NOT NULL"
+ + ") using paimon "
+ + "TBLPROPERTIES('primary-key'='intCol,stringCol2', 'bucket'='2')",
+ PAIMON_TABLE_NAME));
+ }
+
+ protected String warehousePath() {
+ Path path = new Path("/tmp", "paimon-warehouse-" + UUID.randomUUID());
+ return path.toString();
+ }
+
+ @Setup
+ public void setupBenchmark() {
+ setup();
+ }
+
+ @TearDown
+ public void tearDown() {
+ File warehouseDir = new File(warehousePath);
+ deleteFile(warehouseDir);
+ }
+
+ public static boolean deleteFile(File file) {
+ if (file.isDirectory()) {
+ File[] files = file.listFiles();
+ if (files != null) {
+ for (File child : files) {
+ deleteFile(child);
+ }
+ }
+ }
+ return file.delete();
+ }
+
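+ // v1Write uses the existing write path; v2Write below enables the DataSource V2
+ // write path via SparkSQLProperties.USE_V2_WRITE.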
+ @Benchmark
+ @Threads(1)
+ public void v1Write() throws NoSuchTableException {
+ benchmarkData().writeTo(PAIMON_TABLE_NAME).append();
+ }
+
+ @Benchmark
+ @Threads(1)
+ public void v2Write() {
+ Map<String, String> conf = ImmutableMap.of(SparkSQLProperties.USE_V2_WRITE, "true");
+ withSQLConf(
+ conf,
+ () -> {
+ try {
+ benchmarkData().writeTo(PAIMON_TABLE_NAME).append();
+ } catch (NoSuchTableException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ private Dataset<Row> benchmarkData() {
+ return spark.range(NUM_ROWS)
+ .withColumn("intCol", expr("CAST(id AS INT)"))
+ .withColumn("longCol", expr("CAST(id AS LONG)"))
+ .withColumn("floatCol", expr("CAST(id AS FLOAT)"))
+ .withColumn("doubleCol", expr("CAST(id AS DOUBLE)"))
+ .withColumn("decimalCol", expr("CAST(id AS DECIMAL(20, 5))"))
+ .withColumn("stringCol1", expr("CAST(id AS STRING)"))
+ .withColumn("stringCol2", expr("CAST(id AS STRING)"))
+ .withColumn("stringCol3", expr("CAST(id AS STRING)"))
+ .drop("id")
+ .coalesce(1);
+ }
+
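+ // Applies the given Spark SQL conf entries only for the duration of the action and
+ // then restores (or unsets) the previous values, so benchmark runs stay isolated.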
+ private void withSQLConf(Map<String, String> sparkSqlConf, Action action) {
+ SQLConf sqlConf = SQLConf.get();
+
+ Map<String, String> currentConfValues = Maps.newHashMap();
+ sparkSqlConf
+ .keySet()
+ .forEach(
+ confKey -> {
+ if (sqlConf.contains(confKey)) {
+ String currentConfValue = sqlConf.getConfString(confKey);
+ currentConfValues.put(confKey, currentConfValue);
+ }
+ });
+
+ sparkSqlConf.forEach(
+ (confKey, confValue) -> {
+ if (SQLConf.isStaticConfigKey(confKey)) {
+ throw new RuntimeException(
+ "Cannot modify the value of a static config: " + confKey);
+ }
+ sqlConf.setConfString(confKey, confValue);
+ });
+ try {
+ action.invoke();
+ } finally {
+ sparkSqlConf.forEach(
+ (confKey, confValue) -> {
+ if (currentConfValues.containsKey(confKey)) {
+ sqlConf.setConfString(confKey, currentConfValues.get(confKey));
+ } else {
+ sqlConf.unsetConf(confKey);
+ }
+ });
+ }
+ }
+
+ /** Action functional interface. */
+ @FunctionalInterface
+ public interface Action {
+ void invoke();
+ }
+}
diff --git a/paimon-benchmark/pom.xml b/paimon-benchmark/pom.xml
index 75e9a1d51ba4..e8093ac79914 100644
--- a/paimon-benchmark/pom.xml
+++ b/paimon-benchmark/pom.xml
@@ -36,6 +36,7 @@ under the License.
         <module>paimon-cluster-benchmark</module>
         <module>paimon-micro-benchmarks</module>
+        <module>paimon-spark-benchmark</module>
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkConversions.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkConversions.java
new file mode 100644
index 000000000000..c4f0e475a358
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkConversions.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark;
+
+import org.apache.paimon.utils.Preconditions;
+
+import org.apache.spark.sql.types.Decimal;
+
+import java.util.function.Function;
+
+/** Data conversions between Spark and Paimon. */
+public class SparkConversions {
+ private SparkConversions() {}
+
+ public static Function<Decimal, org.apache.paimon.data.Decimal> decimalSparkToPaimon(
+         int precision, int scale) {
+ Preconditions.checkArgument(
+ precision <= 38,
+ "Decimals with precision larger than 38 are not supported: %s",
+ precision);
+
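+ // Compact Paimon decimals keep the unscaled value in a long, so the conversion can
+ // skip materializing a java.math.BigDecimal.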
+ if (org.apache.paimon.data.Decimal.isCompact(precision)) {
+ return sparkDecimal ->
+ org.apache.paimon.data.Decimal.fromUnscaledLong(
+ sparkDecimal.toUnscaledLong(), precision, scale);
+ } else {
+ return sparkDecimal ->
+ org.apache.paimon.data.Decimal.fromBigDecimal(
+ sparkDecimal.toJavaBigDecimal(), precision, scale);
+ }
+ }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java
new file mode 100644
index 000000000000..deae55b4f4d2
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark;
+
+import org.apache.paimon.data.BinaryString;
+import org.apache.paimon.data.Decimal;
+import org.apache.paimon.data.InternalArray;
+import org.apache.paimon.data.InternalMap;
+import org.apache.paimon.data.Timestamp;
+import org.apache.paimon.data.variant.Variant;
+import org.apache.paimon.types.ArrayType;
+import org.apache.paimon.types.DataType;
+import org.apache.paimon.types.MapType;
+import org.apache.paimon.types.RowKind;
+import org.apache.paimon.types.RowType;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.MapData;
+
+import java.io.Serializable;
+
+/** A wrapper that adapts a Spark {@link InternalRow} to the Paimon {@link org.apache.paimon.data.InternalRow} interface. */
+public class SparkInternalRowWrapper implements org.apache.paimon.data.InternalRow, Serializable {
+ private final RowType paimonRowType;
+ private final InternalRow sparkInternalRow;
+ private final RowKind rowKind;
+
+ public SparkInternalRowWrapper(RowType paimonType, InternalRow sparkInternalRow) {
+ this(paimonType, RowKind.INSERT, sparkInternalRow);
+ }
+
+ public SparkInternalRowWrapper(
+ RowType paimonType, RowKind rowKind, InternalRow sparkInternalRow) {
+ this.paimonRowType = paimonType;
+ this.rowKind = rowKind;
+ this.sparkInternalRow = sparkInternalRow;
+ }
+
+ @Override
+ public int getFieldCount() {
+ return sparkInternalRow.numFields();
+ }
+
+ @Override
+ public RowKind getRowKind() {
+ return rowKind;
+ }
+
+ @Override
+ public void setRowKind(RowKind rowKind) {
+ throw new UnsupportedOperationException(
+ "SparkInternalRowWrapper does not support modifying row kind");
+ }
+
+ @Override
+ public boolean isNullAt(int i) {
+ return sparkInternalRow.isNullAt(i);
+ }
+
+ @Override
+ public boolean getBoolean(int i) {
+ return sparkInternalRow.getBoolean(i);
+ }
+
+ @Override
+ public byte getByte(int i) {
+ return sparkInternalRow.getByte(i);
+ }
+
+ @Override
+ public short getShort(int i) {
+ return sparkInternalRow.getShort(i);
+ }
+
+ @Override
+ public int getInt(int i) {
+ return sparkInternalRow.getInt(i);
+ }
+
+ @Override
+ public long getLong(int i) {
+ return sparkInternalRow.getLong(i);
+ }
+
+ @Override
+ public float getFloat(int i) {
+ return sparkInternalRow.getFloat(i);
+ }
+
+ @Override
+ public double getDouble(int i) {
+ return sparkInternalRow.getDouble(i);
+ }
+
+ @Override
+ public BinaryString getString(int i) {
+ return sparkInternalRow.isNullAt(i)
+ ? null
+ : BinaryString.fromBytes(sparkInternalRow.getUTF8String(i).getBytes());
+ }
+
+ @Override
+ public Decimal getDecimal(int i, int precision, int scale) {
+ return sparkInternalRow.isNullAt(i)
+ ? null
+ : Decimal.fromBigDecimal(
+ sparkInternalRow.getDecimal(i, precision, scale).toJavaBigDecimal(),
+ precision,
+ scale);
+ }
+
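+ // Spark represents TIMESTAMP and TIMESTAMP_NTZ values as microseconds in a long,
+ // which maps directly onto Timestamp.fromMicros.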
+ @Override
+ public Timestamp getTimestamp(int i, int precision) {
+ return sparkInternalRow.isNullAt(i)
+ ? null
+ : Timestamp.fromMicros(sparkInternalRow.getLong(i));
+ }
+
+ @Override
+ public byte[] getBinary(int i) {
+ return sparkInternalRow.getBinary(i);
+ }
+
+ @Override
+ public Variant getVariant(int pos) {
+ // TODO: support variant, e.g. via SparkShimLoader.getSparkShim().toPaimonVariant
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public InternalArray getArray(int i) {
+ return sparkInternalRow.isNullAt(i)
+ ? null
+ : new SparkArrayDataWrapper(
+ ((ArrayType) paimonRowType.getTypeAt(i)).getElementType(),
+ sparkInternalRow.getArray(i));
+ }
+
+ @Override
+ public InternalMap getMap(int i) {
+ return sparkInternalRow.isNullAt(i)
+ ? null
+ : toPaimonInternalMap(
+ (MapType) paimonRowType.getTypeAt(i), sparkInternalRow.getMap(i));
+ }
+
+ @Override
+ public org.apache.paimon.data.InternalRow getRow(int i, int numFields) {
+ return sparkInternalRow.isNullAt(i)
+ ? null
+ : new SparkInternalRowWrapper(
+ (RowType) paimonRowType.getTypeAt(i),
+ sparkInternalRow.getStruct(i, numFields));
+ }
+
+ private static InternalMap toPaimonInternalMap(MapType mapType, MapData sparkMapData) {
+ SparkArrayDataWrapper keyArrayWrapper =
+ new SparkArrayDataWrapper(mapType.getKeyType(), sparkMapData.keyArray());
+ SparkArrayDataWrapper valueArrayWrapper =
+ new SparkArrayDataWrapper(mapType.getValueType(), sparkMapData.valueArray());
+ return new InternalMap() {
+ @Override
+ public int size() {
+ return sparkMapData.numElements();
+ }
+
+ @Override
+ public InternalArray keyArray() {
+ return keyArrayWrapper;
+ }
+
+ @Override
+ public InternalArray valueArray() {
+ return valueArrayWrapper;
+ }
+ };
+ }
+
+ private static class SparkArrayDataWrapper implements InternalArray {
+
+ private final DataType elementPaimonType;
+ private final ArrayData sparkArray;
+
+ private SparkArrayDataWrapper(DataType elementPaimonType, ArrayData sparkArray) {
+ this.sparkArray = sparkArray;
+ this.elementPaimonType = elementPaimonType;
+ }
+
+ @Override
+ public int size() {
+ return sparkArray.numElements();
+ }
+
+ @Override
+ public boolean isNullAt(int i) {
+ return sparkArray.isNullAt(i);
+ }
+
+ @Override
+ public boolean getBoolean(int i) {
+ return sparkArray.getBoolean(i);
+ }
+
+ @Override
+ public byte getByte(int i) {
+ return sparkArray.getByte(i);
+ }
+
+ @Override
+ public short getShort(int i) {
+ return sparkArray.getShort(i);
+ }
+
+ @Override
+ public int getInt(int i) {
+ return sparkArray.getInt(i);
+ }
+
+ @Override
+ public long getLong(int i) {
+ return sparkArray.getLong(i);
+ }
+
+ @Override
+ public float getFloat(int i) {
+ return sparkArray.getFloat(i);
+ }
+
+ @Override
+ public double getDouble(int i) {
+ return sparkArray.getDouble(i);
+ }
+
+ @Override
+ public BinaryString getString(int i) {
+ return sparkArray.isNullAt(i)
+ ? null
+ : BinaryString.fromBytes(sparkArray.getUTF8String(i).getBytes());
+ }
+
+ @Override
+ public Decimal getDecimal(int i, int precision, int scale) {
+ return sparkArray.isNullAt(i)
+ ? null
+ : Decimal.fromBigDecimal(
+ sparkArray.getDecimal(i, precision, scale).toJavaBigDecimal(),
+ precision,
+ scale);
+ }
+
+ @Override
+ public Timestamp getTimestamp(int i, int precision) {
+ return sparkArray.isNullAt(i) ? null : Timestamp.fromMicros(sparkArray.getLong(i));
+ }
+
+ @Override
+ public byte[] getBinary(int i) {
+ return sparkArray.getBinary(i);
+ }
+
+ @Override
+ public Variant getVariant(int pos) {
+ // TODO
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public InternalArray getArray(int i) {
+ return sparkArray.isNullAt(i)
+ ? null
+ : new SparkArrayDataWrapper(
+ ((ArrayType) elementPaimonType).getElementType(),
+ sparkArray.getArray(i));
+ }
+
+ @Override
+ public InternalMap getMap(int i) {
+ return sparkArray.isNullAt(i)
+ ? null
+ : toPaimonInternalMap((MapType) elementPaimonType, sparkArray.getMap(i));
+ }
+
+ @Override
+ public org.apache.paimon.data.InternalRow getRow(int i, int numFields) {
+ return sparkArray.isNullAt(i)
+ ? null
+ : new SparkInternalRowWrapper(
+ (RowType) elementPaimonType, sparkArray.getStruct(i, numFields));
+ }
+
+ @Override
+ public boolean[] toBooleanArray() {
+ return sparkArray.toBooleanArray();
+ }
+
+ @Override
+ public byte[] toByteArray() {
+ return sparkArray.toByteArray();
+ }
+
+ @Override
+ public short[] toShortArray() {
+ return sparkArray.toShortArray();
+ }
+
+ @Override
+ public int[] toIntArray() {
+ return sparkArray.toIntArray();
+ }
+
+ @Override
+ public long[] toLongArray() {
+ return sparkArray.toLongArray();
+ }
+
+ @Override
+ public float[] toFloatArray() {
+ return sparkArray.toFloatArray();
+ }
+
+ @Override
+ public double[] toDoubleArray() {
+ return sparkArray.toDoubleArray();
+ }
+ }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteRequirement.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteRequirement.java
new file mode 100644
index 000000000000..8d4a39becca7
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteRequirement.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark;
+
+import org.apache.paimon.table.BucketMode;
+import org.apache.paimon.table.BucketSpec;
+import org.apache.paimon.table.FileStoreTable;
+
+import org.apache.paimon.shade.guava30.com.google.common.collect.Lists;
+
+import org.apache.spark.sql.connector.distributions.ClusteredDistribution;
+import org.apache.spark.sql.connector.distributions.Distribution;
+import org.apache.spark.sql.connector.distributions.Distributions;
+import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.Expressions;
+import org.apache.spark.sql.connector.expressions.SortOrder;
+
+import java.util.List;
+
+/** Distribution and ordering requirements of a Spark write. */
+public class SparkWriteRequirement {
+ public static final SparkWriteRequirement EMPTY =
+ new SparkWriteRequirement(Distributions.unspecified());
+
+ private static final SortOrder[] EMPTY_ORDERING = new SortOrder[0];
+ private final Distribution distribution;
+
+ public static SparkWriteRequirement of(FileStoreTable table) {
+ BucketSpec bucketSpec = table.bucketSpec();
+ BucketMode bucketMode = bucketSpec.getBucketMode();
+ switch (bucketMode) {
+ case HASH_FIXED:
+ case BUCKET_UNAWARE:
+ break;
+ default:
+ throw new UnsupportedOperationException(
+ String.format("Unsupported bucket mode %s", bucketMode));
+ }
+
+ List<Expression> clusteringExpressions = Lists.newArrayList();
+
+ List<String> partitionKeys = table.schema().partitionKeys();
+ for (String partitionKey : partitionKeys) {
+ clusteringExpressions.add(Expressions.identity(quote(partitionKey)));
+ }
+
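+ // For fixed-bucket tables, additionally cluster by a bucket transform over the bucket
+ // keys so that Spark shuffles rows the same way Paimon assigns buckets.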
+ if (bucketMode == BucketMode.HASH_FIXED) {
+ String[] quotedBucketKeys =
+ bucketSpec.getBucketKeys().stream()
+ .map(SparkWriteRequirement::quote)
+ .toArray(String[]::new);
+ clusteringExpressions.add(
+ Expressions.bucket(bucketSpec.getNumBuckets(), quotedBucketKeys));
+ }
+
+ if (clusteringExpressions.isEmpty()) {
+ return EMPTY;
+ }
+
+ ClusteredDistribution distribution =
+         Distributions.clustered(clusteringExpressions.toArray(new Expression[0]));
+
+ return new SparkWriteRequirement(distribution);
+ }
+
+ private static String quote(String columnName) {
+ return String.format("`%s`", columnName.replace("`", "``"));
+ }
+
+ private SparkWriteRequirement(Distribution distribution) {
+ this.distribution = distribution;
+ }
+
+ public Distribution distribution() {
+ return distribution;
+ }
+
+ public SortOrder[] ordering() {
+ return EMPTY_ORDERING;
+ }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/BucketFunction.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/BucketFunction.java
new file mode 100644
index 000000000000..34bb769f0ef9
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/BucketFunction.java
@@ -0,0 +1,724 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.catalog.functions;
+
+import org.apache.paimon.data.BinaryRow;
+import org.apache.paimon.data.BinaryRowWriter;
+import org.apache.paimon.data.BinaryString;
+import org.apache.paimon.data.Timestamp;
+import org.apache.paimon.spark.SparkConversions;
+import org.apache.paimon.table.BucketMode;
+import org.apache.paimon.table.FileStoreTable;
+import org.apache.paimon.types.ArrayType;
+import org.apache.paimon.types.MapType;
+import org.apache.paimon.types.RowType;
+import org.apache.paimon.types.VariantType;
+
+import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableMap;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+import org.apache.spark.sql.types.BinaryType;
+import org.apache.spark.sql.types.BooleanType;
+import org.apache.spark.sql.types.ByteType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.DoubleType;
+import org.apache.spark.sql.types.FloatType;
+import org.apache.spark.sql.types.IntegerType;
+import org.apache.spark.sql.types.LongType;
+import org.apache.spark.sql.types.ShortType;
+import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.TimestampNTZType;
+import org.apache.spark.sql.types.TimestampType;
+import org.apache.spark.unsafe.types.UTF8String;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.lang.reflect.InvocationTargetException;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/** A Spark function implementation for the Paimon bucket transform. */
+public class BucketFunction implements UnboundFunction {
+ private static final int NUM_BUCKETS_ORDINAL = 0;
+ private static final int SPARK_TIMESTAMP_PRECISION = 6;
+
+ private static final Map<String, Class<? extends BucketGeneric>> BUCKET_FUNCTIONS;
+
+ static {
+ ImmutableMap.Builder<String, Class<? extends BucketGeneric>> builder =
+         ImmutableMap.builder();
+ builder.put("BucketBoolean", BucketBoolean.class);
+ builder.put("BucketByte", BucketByte.class);
+ builder.put("BucketShort", BucketShort.class);
+ builder.put("BucketInteger", BucketInteger.class);
+ builder.put("BucketLong", BucketLong.class);
+ builder.put("BucketFloat", BucketFloat.class);
+ builder.put("BucketDouble", BucketDouble.class);
+ builder.put("BucketString", BucketString.class);
+ builder.put("BucketDecimal", BucketDecimal.class);
+ builder.put("BucketTimestamp", BucketTimestamp.class);
+ builder.put("BucketBinary", BucketBinary.class);
+
+ // Joint bucket fields of common types
+ builder.put("BucketIntegerInteger", BucketIntegerInteger.class);
+ builder.put("BucketIntegerLong", BucketIntegerLong.class);
+ builder.put("BucketIntegerString", BucketIntegerString.class);
+ builder.put("BucketLongInteger", BucketLongInteger.class);
+ builder.put("BucketLongLong", BucketLongLong.class);
+ builder.put("BucketLongString", BucketLongString.class);
+ builder.put("BucketStringInteger", BucketStringInteger.class);
+ builder.put("BucketStringLong", BucketStringLong.class);
+ builder.put("BucketStringString", BucketStringString.class);
+
+ BUCKET_FUNCTIONS = builder.build();
+ }
+
+ public static boolean supportsTable(FileStoreTable table) {
+ if (table.bucketMode() != BucketMode.HASH_FIXED) {
+ return false;
+ }
+
+ return table.schema().logicalBucketKeyType().getFieldTypes().stream()
+ .allMatch(BucketFunction::supportsType);
+ }
+
+ private static boolean supportsType(org.apache.paimon.types.DataType type) {
+ if (type instanceof ArrayType
+ || type instanceof MapType
+ || type instanceof RowType
+ || type instanceof VariantType) {
+ return false;
+ }
+
+ if (type instanceof org.apache.paimon.types.TimestampType) {
+ return ((org.apache.paimon.types.TimestampType) type).getPrecision()
+ == SPARK_TIMESTAMP_PRECISION;
+ }
+
+ if (type instanceof org.apache.paimon.types.LocalZonedTimestampType) {
+ return ((org.apache.paimon.types.LocalZonedTimestampType) type).getPrecision()
+ == SPARK_TIMESTAMP_PRECISION;
+ }
+
+ return true;
+ }
+
+ @Override
+ public BoundFunction bind(StructType inputType) {
+ StructField[] fields = inputType.fields();
+
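+ // Field 0 carries numBuckets; the remaining fields are the bucket key columns. A name
+ // such as "BucketIntegerString" selects a specialized implementation with a magic
+ // invoke method, otherwise the generic row-based variant is used.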
+ StringBuilder classNameBuilder = new StringBuilder("Bucket");
+ DataType[] bucketKeyTypes = new DataType[fields.length - 1];
+ for (int i = 1; i < fields.length; i += 1) {
+ DataType dataType = fields[i].dataType();
+ bucketKeyTypes[i - 1] = dataType;
+ if (dataType instanceof BooleanType) {
+ classNameBuilder.append("Boolean");
+ } else if (dataType instanceof ByteType) {
+ classNameBuilder.append("Byte");
+ } else if (dataType instanceof ShortType) {
+ classNameBuilder.append("Short");
+ } else if (dataType instanceof IntegerType) {
+ classNameBuilder.append("Integer");
+ } else if (dataType instanceof LongType) {
+ classNameBuilder.append("Long");
+ } else if (dataType instanceof FloatType) {
+ classNameBuilder.append("Float");
+ } else if (dataType instanceof DoubleType) {
+ classNameBuilder.append("Double");
+ } else if (dataType instanceof StringType) {
+ classNameBuilder.append("String");
+ } else if (dataType instanceof DecimalType) {
+ classNameBuilder.append("Decimal");
+ } else if (dataType instanceof TimestampType || dataType instanceof TimestampNTZType) {
+ classNameBuilder.append("Timestamp");
+ } else if (dataType instanceof BinaryType) {
+ classNameBuilder.append("Binary");
+ } else {
+ throw new UnsupportedOperationException(
+ "Unsupported type: " + dataType.simpleString());
+ }
+ }
+
+ Class<? extends BucketGeneric> bucketClass =
+ BUCKET_FUNCTIONS.getOrDefault(classNameBuilder.toString(), BucketGeneric.class);
+
+ try {
+ return bucketClass
+ .getConstructor(DataType[].class)
+ .newInstance(
+ (Object)
+ bucketKeyTypes /* cast DataType[] to Object as newInstance takes varargs */);
+ } catch (InstantiationException
+ | IllegalAccessException
+ | InvocationTargetException
+ | NoSuchMethodException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public String description() {
+ return name() + "(numBuckets, col1, col2, ...)";
+ }
+
+ @Override
+ public String name() {
+ return "bucket";
+ }
+
+ /** Bound bucket function for generic type. */
+ public static class BucketGeneric implements ScalarFunction<Integer> {
+ protected final DataType[] bucketKeyTypes;
+ protected final BinaryRow bucketKeyRow;
+ // not serializable
+ protected transient BinaryRowWriter bucketKeyWriter;
+ private transient ValueWriter[] valueWriters;
+
+ public BucketGeneric(DataType[] sqlTypes) {
+ this.bucketKeyTypes = sqlTypes;
+ this.bucketKeyRow = new BinaryRow(bucketKeyTypes.length);
+ this.bucketKeyWriter = new BinaryRowWriter(bucketKeyRow);
+ this.valueWriters = createValueWriter(bucketKeyTypes);
+ }
+
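+ // BinaryRowWriter and the value writers are not serializable, so they are rebuilt
+ // after this function is deserialized on an executor.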
+ private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
+ in.defaultReadObject();
+ this.bucketKeyWriter = new BinaryRowWriter(bucketKeyRow);
+ this.valueWriters = createValueWriter(bucketKeyTypes);
+ }
+
+ private static ValueWriter[] createValueWriter(DataType[] columnTypes) {
+ ValueWriter[] writers = new ValueWriter[columnTypes.length];
+ for (int i = 0; i < columnTypes.length; i += 1) {
+ writers[i] = ValueWriter.of(columnTypes[i]);
+ }
+
+ return writers;
+ }
+
+ @Override
+ public Integer produceResult(InternalRow input) {
+ bucketKeyWriter.reset();
+ for (int i = 0; i < valueWriters.length; i += 1) {
+ valueWriters[i].write(bucketKeyWriter, i, input, i + 1);
+ }
+
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, input.getInt(NUM_BUCKETS_ORDINAL));
+ }
+
+ @Override
+ public DataType[] inputTypes() {
+ DataType[] inputTypes = new DataType[bucketKeyTypes.length + 1];
+ inputTypes[0] = DataTypes.IntegerType;
+ for (int i = 0; i < bucketKeyTypes.length; i += 1) {
+ inputTypes[i + 1] = bucketKeyTypes[i];
+ }
+
+ return inputTypes;
+ }
+
+ @Override
+ public DataType resultType() {
+ return DataTypes.IntegerType;
+ }
+
+ @Override
+ public boolean isResultNullable() {
+ return false;
+ }
+
+ @Override
+ public String name() {
+ return "bucket";
+ }
+
+ @Override
+ public String canonicalName() {
+ return String.format(
+ "paimon.bucket(%s)",
+ Arrays.stream(bucketKeyTypes)
+ .map(DataType::catalogString)
+ .collect(Collectors.joining(", ")));
+ }
+
+ // org.apache.paimon.table.sink.KeyAndBucketExtractor.bucket(int, int)
+ protected static int bucket(BinaryRow bucketKey, int numBuckets) {
+ return Math.abs(bucketKey.hashCode() % numBuckets);
+ }
+ }
+
+ /** Bound bucket function for {Boolean} type. */
+ public static class BucketBoolean extends BucketGeneric {
+
+ public BucketBoolean(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, boolean value0) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeBoolean(0, value0);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bound bucket function for {Byte} type. */
+ public static class BucketByte extends BucketGeneric {
+
+ public BucketByte(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, byte value0) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeByte(0, value0);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bound bucket function for {Short} type. */
+ public static class BucketShort extends BucketGeneric {
+
+ public BucketShort(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, short value0) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeShort(0, value0);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bound bucket function for {Integer} type. */
+ public static class BucketInteger extends BucketGeneric {
+
+ public BucketInteger(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, int value0) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeInt(0, value0);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bound bucket function for {Long} type. */
+ public static class BucketLong extends BucketGeneric {
+
+ public BucketLong(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, long value0) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeLong(0, value0);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bound bucket function for {Float} type. */
+ public static class BucketFloat extends BucketGeneric {
+
+ public BucketFloat(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, float value0) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeFloat(0, value0);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bound bucket function for {Double} type. */
+ public static class BucketDouble extends BucketGeneric {
+
+ public BucketDouble(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, double value0) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeDouble(0, value0);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bound bucket function for {String} type. */
+ public static class BucketString extends BucketGeneric {
+
+ public BucketString(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, UTF8String value0) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeString(0, BinaryString.fromBytes(value0.getBytes()));
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bound bucket function for {Decimal} type. */
+ public static class BucketDecimal extends BucketGeneric {
+ private final int precision;
+ private final int scale;
+ private final Function<Decimal, org.apache.paimon.data.Decimal> converter;
+
+ public BucketDecimal(DataType[] sqlTypes) {
+ super(sqlTypes);
+ DecimalType decimalType = (DecimalType) sqlTypes[0];
+ this.precision = decimalType.precision();
+ this.scale = decimalType.scale();
+ this.converter = SparkConversions.decimalSparkToPaimon(precision, scale);
+ }
+
+ public int invoke(int numBuckets, Decimal value0) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeDecimal(0, converter.apply(value0), precision);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bound bucket function for paimon {@code TimestampType} type with precision 6. */
+ public static class BucketTimestamp extends BucketGeneric {
+
+ public BucketTimestamp(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, long value0) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeTimestamp(
+ 0, Timestamp.fromMicros(value0), SPARK_TIMESTAMP_PRECISION);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bound bucket function for {Binary} type. */
+ public static class BucketBinary extends BucketGeneric {
+
+ public BucketBinary(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, byte[] value0) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeBinary(0, value0);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bucket function for {Integer, Integer} composite type. */
+ public static class BucketIntegerInteger extends BucketGeneric {
+
+ public BucketIntegerInteger(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, int value0, int value1) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeInt(0, value0);
+ bucketKeyWriter.writeInt(1, value1);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bucket function for {Integer, Long} composite type. */
+ public static class BucketIntegerLong extends BucketGeneric {
+
+ public BucketIntegerLong(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, int value0, long value1) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeInt(0, value0);
+ bucketKeyWriter.writeLong(1, value1);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bucket function for {Integer, String} composite type. */
+ public static class BucketIntegerString extends BucketGeneric {
+
+ public BucketIntegerString(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, int value0, UTF8String value1) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeInt(0, value0);
+ bucketKeyWriter.writeString(1, BinaryString.fromBytes(value1.getBytes()));
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bucket function for {Long, Long} composite type. */
+ public static class BucketLongLong extends BucketGeneric {
+
+ public BucketLongLong(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, long value0, long value1) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeLong(0, value0);
+ bucketKeyWriter.writeLong(1, value1);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bucket function for {Long, Integer} composite type. */
+ public static class BucketLongInteger extends BucketGeneric {
+
+ public BucketLongInteger(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, long value0, int value1) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeLong(0, value0);
+ bucketKeyWriter.writeInt(1, value1);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bucket function for {Long, String} composite type. */
+ public static class BucketLongString extends BucketGeneric {
+
+ public BucketLongString(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, long value0, UTF8String value1) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeLong(0, value0);
+ bucketKeyWriter.writeString(1, BinaryString.fromBytes(value1.getBytes()));
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bucket function for {String, Integer} composite type. */
+ public static class BucketStringInteger extends BucketGeneric {
+
+ public BucketStringInteger(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, UTF8String value0, int value1) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeString(0, BinaryString.fromBytes(value0.getBytes()));
+ bucketKeyWriter.writeInt(1, value1);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bucket function for {String, Long} composite type. */
+ public static class BucketStringLong extends BucketGeneric {
+
+ public BucketStringLong(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, UTF8String value0, long value1) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeString(0, BinaryString.fromBytes(value0.getBytes()));
+ bucketKeyWriter.writeLong(1, value1);
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ /** Bucket function for {String, String} composite type. */
+ public static class BucketStringString extends BucketGeneric {
+
+ public BucketStringString(DataType[] sqlTypes) {
+ super(sqlTypes);
+ }
+
+ public int invoke(int numBuckets, UTF8String value0, UTF8String value1) {
+ bucketKeyWriter.reset();
+ bucketKeyWriter.writeString(0, BinaryString.fromBytes(value0.getBytes()));
+ bucketKeyWriter.writeString(1, BinaryString.fromBytes(value1.getBytes()));
+ bucketKeyWriter.complete();
+ return bucket(bucketKeyRow, numBuckets);
+ }
+ }
+
+ @FunctionalInterface
+ interface ValueWriter {
+ void write(BinaryRowWriter writer, int writePos, InternalRow srcRow, int srcPos);
+
+ static ValueWriter of(DataType type) {
+ if (type instanceof BooleanType) {
+ return (writer, writePos, srcRow, srcPos) -> {
+ if (srcRow.isNullAt(srcPos)) {
+ writer.setNullAt(writePos);
+ } else {
+ writer.writeBoolean(writePos, srcRow.getBoolean(srcPos));
+ }
+ };
+
+ } else if (type instanceof ByteType) {
+ return (writer, writePos, srcRow, srcPos) -> {
+ if (srcRow.isNullAt(srcPos)) {
+ writer.setNullAt(writePos);
+ } else {
+ writer.writeByte(writePos, srcRow.getByte(srcPos));
+ }
+ };
+
+ } else if (type instanceof ShortType) {
+ return (writer, writePos, srcRow, srcPos) -> {
+ if (srcRow.isNullAt(srcPos)) {
+ writer.setNullAt(writePos);
+ } else {
+ writer.writeShort(writePos, srcRow.getShort(srcPos));
+ }
+ };
+
+ } else if (type instanceof IntegerType) {
+ return (writer, writePos, srcRow, srcPos) -> {
+ if (srcRow.isNullAt(srcPos)) {
+ writer.setNullAt(writePos);
+ } else {
+ writer.writeInt(writePos, srcRow.getInt(srcPos));
+ }
+ };
+
+ } else if (type instanceof LongType) {
+ return (writer, writePos, srcRow, srcPos) -> {
+ if (srcRow.isNullAt(srcPos)) {
+ writer.setNullAt(writePos);
+ } else {
+ writer.writeLong(writePos, srcRow.getLong(srcPos));
+ }
+ };
+
+ } else if (type instanceof FloatType) {
+ return (writer, writePos, srcRow, srcPos) -> {
+ if (srcRow.isNullAt(srcPos)) {
+ writer.setNullAt(writePos);
+ } else {
+ writer.writeFloat(writePos, srcRow.getFloat(srcPos));
+ }
+ };
+
+ } else if (type instanceof DoubleType) {
+ return (writer, writePos, srcRow, srcPos) -> {
+ if (srcRow.isNullAt(srcPos)) {
+ writer.setNullAt(writePos);
+ } else {
+ writer.writeDouble(writePos, srcRow.getDouble(srcPos));
+ }
+ };
+
+ } else if (type instanceof StringType) {
+ return (writer, writePos, srcRow, srcPos) -> {
+ if (srcRow.isNullAt(srcPos)) {
+ writer.setNullAt(writePos);
+ } else {
+ writer.writeString(
+ writePos,
+ BinaryString.fromBytes(srcRow.getUTF8String(srcPos).getBytes()));
+ }
+ };
+
+ } else if (type instanceof BinaryType) {
+ return (writer, writePos, srcRow, srcPos) -> {
+ if (srcRow.isNullAt(srcPos)) {
+ writer.setNullAt(writePos);
+ } else {
+ writer.writeBinary(writePos, srcRow.getBinary(srcPos));
+ }
+ };
+
+ } else if (type instanceof DecimalType) {
+ DecimalType decimalType = (DecimalType) type;
+ int precision = decimalType.precision();
+ int scale = decimalType.scale();
+ boolean compact = org.apache.paimon.data.Decimal.isCompact(precision);
+ Function<Decimal, org.apache.paimon.data.Decimal> converter =
+         SparkConversions.decimalSparkToPaimon(precision, scale);
+
+ return (writer, writePos, srcRow, srcPos) -> {
+ if (srcRow.isNullAt(srcPos)) {
+ // org.apache.paimon.codegen.GenerateUtils.binaryWriterWriteNull
+ if (!compact) {
+ writer.writeDecimal(writePos, null, precision);
+ } else {
+ writer.setNullAt(writePos);
+ }
+ } else {
+ Decimal decimal = srcRow.getDecimal(srcPos, precision, scale);
+ writer.writeDecimal(writePos, converter.apply(decimal), precision);
+ }
+ };
+
+ } else if (type instanceof TimestampType || type instanceof TimestampNTZType) {
+ return (writer, writePos, srcRow, srcPos) -> {
+ if (srcRow.isNullAt(srcPos)) {
+ // org.apache.paimon.codegen.GenerateUtils.binaryWriterWriteNull
+ // precision 6 timestamps are never compact, so the null must be written through writeTimestamp
+ writer.writeTimestamp(writePos, null, SPARK_TIMESTAMP_PRECISION);
+ } else {
+ writer.writeTimestamp(
+ writePos,
+ Timestamp.fromMicros(srcRow.getLong(srcPos)),
+ SPARK_TIMESTAMP_PRECISION);
+ }
+ };
+
+ } else {
+ throw new UnsupportedOperationException("Unsupported type: " + type.simpleString());
+ }
+ }
+ }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/PaimonFunctions.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/PaimonFunctions.java
index c7949e11948f..b265419b535b 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/PaimonFunctions.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/PaimonFunctions.java
@@ -35,7 +35,6 @@
import java.util.Map;
import static org.apache.paimon.utils.Preconditions.checkArgument;
-import static org.apache.spark.sql.types.DataTypes.IntegerType;
import static org.apache.spark.sql.types.DataTypes.StringType;
/** Paimon functions. */
@@ -57,60 +56,6 @@ public static UnboundFunction load(String name) {
return FUNCTIONS.get(name);
}
- /**
- * For now, we only support report bucket partitioning for table scan. So the case `SELECT
- * bucket(10, col)` would fail since we do not implement {@link
- * org.apache.spark.sql.connector.catalog.functions.ScalarFunction}
- */
- public static class BucketFunction implements UnboundFunction {
- @Override
- public BoundFunction bind(StructType inputType) {
- if (inputType.size() != 2) {
- throw new UnsupportedOperationException(
- "Wrong number of inputs (expected numBuckets and value)");
- }
-
- StructField numBucket = inputType.fields()[0];
- StructField bucketField = inputType.fields()[1];
- checkArgument(
- numBucket.dataType() == IntegerType,
- "bucket number field must be integer type");
-
- return new BoundFunction() {
- @Override
- public DataType[] inputTypes() {
- return new DataType[] {IntegerType, bucketField.dataType()};
- }
-
- @Override
- public DataType resultType() {
- return IntegerType;
- }
-
- @Override
- public String name() {
- return "bucket";
- }
-
- @Override
- public String canonicalName() {
- // We have to override this method to make it support canonical equivalent
- return "paimon.bucket(" + bucketField.dataType().catalogString() + ", int)";
- }
- };
- }
-
- @Override
- public String description() {
- return name();
- }
-
- @Override
- public String name() {
- return "bucket";
- }
- }
-
/**
* For partitioned tables, this function returns the maximum value of the first level partition
* of the partitioned table, sorted alphabetically. Note, empty partitions will be skipped. For
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/source/SparkV2Write.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/source/SparkV2Write.java
new file mode 100644
index 000000000000..82b43b8d6557
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/source/SparkV2Write.java
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.source;
+
+import org.apache.paimon.disk.IOManager;
+import org.apache.paimon.spark.SparkInternalRowWrapper;
+import org.apache.paimon.spark.SparkUtils;
+import org.apache.paimon.spark.SparkWriteRequirement;
+import org.apache.paimon.table.FileStoreTable;
+import org.apache.paimon.table.sink.BatchTableCommit;
+import org.apache.paimon.table.sink.BatchTableWrite;
+import org.apache.paimon.table.sink.BatchWriteBuilder;
+import org.apache.paimon.table.sink.CommitMessage;
+import org.apache.paimon.table.sink.CommitMessageSerializer;
+import org.apache.paimon.types.RowType;
+import org.apache.paimon.utils.Preconditions;
+
+import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableMap;
+import org.apache.paimon.shade.guava30.com.google.common.collect.Lists;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.distributions.Distribution;
+import org.apache.spark.sql.connector.expressions.SortOrder;
+import org.apache.spark.sql.connector.write.BatchWrite;
+import org.apache.spark.sql.connector.write.DataWriter;
+import org.apache.spark.sql.connector.write.DataWriterFactory;
+import org.apache.spark.sql.connector.write.PhysicalWriteInfo;
+import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering;
+import org.apache.spark.sql.connector.write.Write;
+import org.apache.spark.sql.connector.write.WriterCommitMessage;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.paimon.CoreOptions.DYNAMIC_PARTITION_OVERWRITE;
+
+/** Spark V2 write for Paimon tables. */
+public class SparkV2Write implements Write, RequiresDistributionAndOrdering {
+ private static final Logger LOG = LoggerFactory.getLogger(SparkV2Write.class);
+
+ private final FileStoreTable table;
+ private final BatchWriteBuilder writeBuilder;
+ private final SparkWriteRequirement writeRequirement;
+ private final boolean overwriteDynamic;
+ private final Map<String, String> overwritePartitions;
+
+ public SparkV2Write(
+ FileStoreTable newTable,
+ boolean overwriteDynamic,
+ Map<String, String> overwritePartitions) {
+ Preconditions.checkArgument(
+ !overwriteDynamic || overwritePartitions == null,
+ "Cannot overwrite dynamically and by filter both");
+
+ this.table =
+ newTable.copy(
+ ImmutableMap.of(
+ DYNAMIC_PARTITION_OVERWRITE.key(),
+ Boolean.toString(overwriteDynamic)));
+ this.writeBuilder = table.newBatchWriteBuilder();
+ if (overwritePartitions != null) {
+ writeBuilder.withOverwrite(overwritePartitions);
+ }
+
+ this.writeRequirement = SparkWriteRequirement.of(table);
+ this.overwriteDynamic = overwriteDynamic;
+ this.overwritePartitions = overwritePartitions;
+ }
+
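+ // Spark consults requiredDistribution/requiredOrdering before planning the write and
+ // inserts the corresponding shuffle, so rows arrive grouped by partition and bucket.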
+ @Override
+ public Distribution requiredDistribution() {
+ Distribution distribution = writeRequirement.distribution();
+ LOG.info("Requesting {} as write distribution for table {}", distribution, table.name());
+ return distribution;
+ }
+
+ @Override
+ public SortOrder[] requiredOrdering() {
+ SortOrder[] ordering = writeRequirement.ordering();
+ LOG.info("Requesting {} as write ordering for table {}", ordering, table.name());
+ return ordering;
+ }
+
+ @Override
+ public BatchWrite toBatch() {
+ return new PaimonBatchWrite();
+ }
+
+ private class PaimonBatchWrite implements BatchWrite {
+
+ @Override
+ public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) {
+ return new WriterFactory(table.rowType(), writeBuilder);
+ }
+
+ @Override
+ public boolean useCommitCoordinator() {
+ return false;
+ }
+
+ @Override
+ public void commit(WriterCommitMessage[] messages) {
+ LOG.info("Committing to table {}, ", table.name());
+ BatchTableCommit batchTableCommit = writeBuilder.newCommit();
+ List<CommitMessage> allCommitMessage = Lists.newArrayList();
+
+ for (WriterCommitMessage message : messages) {
+ if (message != null) {
+ List<CommitMessage> commitMessages = ((TaskCommit) message).commitMessages();
+ allCommitMessage.addAll(commitMessages);
+ }
+ }
+
+ try {
+ long start = System.currentTimeMillis();
+ batchTableCommit.commit(allCommitMessage);
+ LOG.info("Committed in {} ms", System.currentTimeMillis() - start);
+ batchTableCommit.close();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void abort(WriterCommitMessage[] messages) {
+ // TODO abort
+ }
+ }
+
+ @Override
+ public String toString() {
+ return String.format(
+ "PaimonWrite(table=%s, %s)",
+ table.name(),
+ overwriteDynamic
+ ? "overwriteDynamic=true"
+ : String.format("overwritePartitions=%s", overwritePartitions));
+ }
+
+ private static class WriterFactory implements DataWriterFactory {
+ private final RowType rowType;
+ private final BatchWriteBuilder batchWriteBuilder;
+
+ WriterFactory(RowType rowType, BatchWriteBuilder batchWriteBuilder) {
+ this.rowType = rowType;
+ this.batchWriteBuilder = batchWriteBuilder;
+ }
+
+ @Override
+ public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
+ BatchTableWrite batchTableWrite = batchWriteBuilder.newWrite();
+ return new GenericWriter(batchTableWrite, rowType);
+ }
+ }
+
+ private static class GenericWriter implements DataWriter<InternalRow> {
+ private final BatchTableWrite batchTableWrite;
+ private final RowType rowType;
+ private final IOManager ioManager;
+
+ private GenericWriter(BatchTableWrite batchTableWrite, RowType rowType) {
+ this.batchTableWrite = batchTableWrite;
+ this.rowType = rowType;
+ this.ioManager = SparkUtils.createIOManager();
+ batchTableWrite.withIOManager(ioManager);
+ }
+
+ @Override
+ public void write(InternalRow record) throws IOException {
+ // TODO rowKind
+ SparkInternalRowWrapper wrappedInternalRow =
+ new SparkInternalRowWrapper(rowType, record);
+ try {
+ batchTableWrite.write(wrappedInternalRow);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public WriterCommitMessage commit() throws IOException {
+ try {
+ List<CommitMessage> commitMessages = batchTableWrite.prepareCommit();
+ return new TaskCommit(commitMessages);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ } finally {
+ close();
+ }
+ }
+
+ @Override
+ public void abort() throws IOException {
+ // TODO clean uncommitted files
+ close();
+ }
+
+ @Override
+ public void close() throws IOException {
+ try {
+ batchTableWrite.close();
+ ioManager.close();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+
+ private static class TaskCommit implements WriterCommitMessage {
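+ // Commit messages are serialized to bytes eagerly so this WriterCommitMessage can be
+ // shipped from executors back to the driver via Java serialization.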
+ private final List<byte[]> serializedMessageList = Lists.newArrayList();
+
+ TaskCommit(List<CommitMessage> commitMessages) {
+ if (commitMessages == null || commitMessages.isEmpty()) {
+ return;
+ }
+
+ CommitMessageSerializer serializer = new CommitMessageSerializer();
+ for (CommitMessage commitMessage : commitMessages) {
+ try {
+ serializedMessageList.add(serializer.serialize(commitMessage));
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+ }
+
+ List<CommitMessage> commitMessages() {
+ CommitMessageSerializer deserializer = new CommitMessageSerializer();
+ List<CommitMessage> commitMessageList = Lists.newArrayList();
+ for (byte[] serializedMessage : serializedMessageList) {
+ try {
+ commitMessageList.add(
+ deserializer.deserialize(deserializer.getVersion(), serializedMessage));
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ return commitMessageList;
+ }
+ }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/BaseWriteBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/BaseWriteBuilder.scala
new file mode 100644
index 000000000000..c6b6e961bfe0
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/BaseWriteBuilder.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark
+
+import org.apache.paimon.options.Options
+import org.apache.paimon.table.FileStoreTable
+import org.apache.paimon.types.RowType
+
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.connector.write.WriteBuilder
+import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, Not, Or}
+
+import scala.collection.JavaConverters._
+
+abstract class BaseWriteBuilder(table: FileStoreTable, options: Options)
+ extends WriteBuilder
+ with SQLConfHelper {
+
+ private def failWithReason(filter: Filter): Unit = {
+ throw new RuntimeException(
+ s"Only support Overwrite filters with Equal and EqualNullSafe, but got: $filter")
+ }
+
+ private def validateFilter(filter: Filter): Unit = filter match {
+ case And(left, right) =>
+ validateFilter(left)
+ validateFilter(right)
+ case _: Or => failWithReason(filter)
+ case _: Not => failWithReason(filter)
+ case e: EqualTo if e.references.length == 1 && !e.value.isInstanceOf[Filter] =>
+ case e: EqualNullSafe if e.references.length == 1 && !e.value.isInstanceOf[Filter] =>
+ case _: AlwaysTrue | _: AlwaysFalse =>
+ case _ => failWithReason(filter)
+ }
+
+ // `SupportsOverwrite#canOverwrite` was added in Spark 3.4.0.
+ // We do this check ourselves to stay compatible with earlier Spark versions.
+ protected def failIfCanNotOverwrite(filters: Array[Filter]): Unit = {
+ // For now, we only support overwrite with two cases:
+ // - overwrite with partition columns to be compatible with v1 insert overwrite
+ // See [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveInsertInto#staticDeleteExpression]].
+ // - truncate-like overwrite where the filter is always true.
+ //
+ // Fail fast for other custom filters that come through the v2 write interface, e.g.,
+ // `dataframe.writeTo(T).overwrite(...)`
+ val partitionRowType = table.schema.logicalPartitionType()
+ val partitionNames = partitionRowType.getFieldNames.asScala
+ val allReferences = filters.flatMap(_.references)
+ val containsDataColumn = allReferences.exists {
+ reference => !partitionNames.exists(conf.resolver.apply(reference, _))
+ }
+ if (containsDataColumn) {
+ throw new RuntimeException(
+ s"Only support Overwrite filters on partition column ${partitionNames.mkString(
+ ", ")}, but got ${filters.mkString(", ")}.")
+ }
+ if (allReferences.distinct.length < allReferences.length) {
+ // fail with `part = 1 and part = 2`
+ throw new RuntimeException(
+ s"Only support Overwrite with one filter for each partition column, but got ${filters.mkString(", ")}.")
+ }
+ filters.foreach(validateFilter)
+ }
+
+ private def parseSaveMode(
+ saveMode: SaveMode,
+ table: FileStoreTable): (Boolean, Map[String, String]) = {
+ var dynamicPartitionOverwriteMode = false
+ val overwritePartition = saveMode match {
+ case InsertInto => null
+ case Overwrite(filter) =>
+ if (filter.isEmpty) {
+ Map.empty[String, String]
+ } else if (isTruncate(filter.get)) {
+ Map.empty[String, String]
+ } else {
+ convertPartitionFilterToMap(filter.get, table.schema.logicalPartitionType())
+ }
+ case DynamicOverWrite =>
+ dynamicPartitionOverwriteMode = true
+ Map.empty[String, String]
+ case _ =>
+ throw new UnsupportedOperationException(s" This mode is unsupported for now.")
+ }
+ (dynamicPartitionOverwriteMode, overwritePartition)
+ }
+
+ /**
+ * For the 'INSERT OVERWRITE' semantics of SQL, Spark DataSourceV2 will call the `truncate`
+ * methods where the `AlwaysTrue` Filter is used.
+ */
+ def isTruncate(filter: Filter): Boolean = {
+ val filters = splitConjunctiveFilters(filter)
+ filters.length == 1 && filters.head.isInstanceOf[AlwaysTrue]
+ }
+
+ /** See [[org.apache.paimon.spark.BaseWriteBuilder#failIfCanNotOverwrite]] */
+ def convertPartitionFilterToMap(
+ filter: Filter,
+ partitionRowType: RowType): Map[String, String] = {
+ // TODO: replace this with SparkV2FilterConverter when we drop Spark 3.2
+ val converter = new SparkFilterConverter(partitionRowType)
+ splitConjunctiveFilters(filter).map {
+ case EqualNullSafe(attribute, value) =>
+ (attribute, converter.convertString(attribute, value))
+ case EqualTo(attribute, value) =>
+ (attribute, converter.convertString(attribute, value))
+ case _ =>
+ // Should not happen
+ throw new RuntimeException(
+ s"Only support Overwrite filters with Equal and EqualNullSafe, but got: $filter")
+ }.toMap
+ }
+
+ private def splitConjunctiveFilters(filter: Filter): Seq[Filter] = {
+ filter match {
+ case And(filter1, filter2) =>
+ splitConjunctiveFilters(filter1) ++ splitConjunctiveFilters(filter2)
+ case other => other :: Nil
+ }
+ }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkSQLProperties.java b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkSQLProperties.java
new file mode 100644
index 000000000000..2101c95f2817
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkSQLProperties.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark;
+
+/** Spark SQL properties for Paimon. */
+public class SparkSQLProperties {
+ private SparkSQLProperties() {}
+
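+ // Enables the Spark DataSourceV2 write path for Paimon tables when set to "true"; the table must
+ // also support v2 write (see SparkTable#supportsV2Write). Can be set per session, e.g.
+ // spark.conf().set(SparkSQLProperties.USE_V2_WRITE, "true") or SET spark.sql.paimon.use-v2-write=true.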
+ public static final String USE_V2_WRITE = "spark.sql.paimon.use-v2-write";
+ public static final String USE_V2_WRITE_DEFAULT = "false";
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala
index b9a90d8b5bef..f6e6a5a51a9a 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala
@@ -20,10 +20,12 @@ package org.apache.paimon.spark
import org.apache.paimon.CoreOptions
import org.apache.paimon.options.Options
+import org.apache.paimon.spark.catalog.functions.BucketFunction
import org.apache.paimon.spark.schema.PaimonMetadataColumn
-import org.apache.paimon.table.{DataTable, FileStoreTable, KnownSplitsTable, Table}
+import org.apache.paimon.table.{BucketMode, DataTable, FileStoreTable, KnownSplitsTable, Table}
import org.apache.paimon.utils.StringUtils
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, SupportsRead, SupportsWrite, TableCapability, TableCatalog}
import org.apache.spark.sql.connector.expressions.{Expressions, Transform}
import org.apache.spark.sql.connector.read.ScanBuilder
@@ -43,6 +45,27 @@ case class SparkTable(table: Table)
with SupportsMetadataColumns
with PaimonPartitionManagement {
+ private lazy val sparkSession: SparkSession = SparkSession.active
+ private lazy val useV2Write: Boolean = {
+ val v2WriteConfigured = sparkSession.conf
+ .get(SparkSQLProperties.USE_V2_WRITE, SparkSQLProperties.USE_V2_WRITE_DEFAULT)
+ .toBoolean
+ v2WriteConfigured && supportsV2Write
+ }
+
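+ // V2 write is only used for tables in HASH_FIXED bucket mode whose bucket key is supported by
+ // BucketFunction, or tables in BUCKET_UNAWARE mode; all other tables fall back to the V1 write path.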
+ private def supportsV2Write: Boolean = {
+ table match {
+ case storeTable: FileStoreTable =>
+ storeTable.bucketMode() match {
+ case BucketMode.HASH_FIXED => BucketFunction.supportsTable(storeTable)
+ case BucketMode.BUCKET_UNAWARE => true
+ case _ => false
+ }
+
+ case _ => false
+ }
+ }
+
def getTable: Table = table
override def name: String = table.fullName
@@ -73,14 +96,18 @@ case class SparkTable(table: Table)
}
override def capabilities: JSet[TableCapability] = {
- JEnumSet.of(
+ val capabilities = JEnumSet.of(
TableCapability.ACCEPT_ANY_SCHEMA,
TableCapability.BATCH_READ,
- TableCapability.V1_BATCH_WRITE,
TableCapability.OVERWRITE_BY_FILTER,
TableCapability.OVERWRITE_DYNAMIC,
TableCapability.MICRO_BATCH_READ
)
+
+ capabilities.add(
+ if (useV2Write) TableCapability.BATCH_WRITE else TableCapability.V1_BATCH_WRITE)
+
+ capabilities
}
override def metadataColumns: Array[MetadataColumn] = {
@@ -105,7 +132,12 @@ case class SparkTable(table: Table)
override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
table match {
case fileStoreTable: FileStoreTable =>
- new SparkWriteBuilder(fileStoreTable, Options.fromMap(info.options))
+ val options = Options.fromMap(info.options)
+ if (useV2Write) {
+ new SparkV2WriteBuilder(fileStoreTable, options)
+ } else {
+ new SparkWriteBuilder(fileStoreTable, options)
+ }
case _ =>
throw new RuntimeException("Only FileStoreTable can be written.")
}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2WriteBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2WriteBuilder.scala
new file mode 100644
index 000000000000..377992a34720
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2WriteBuilder.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark
+
+import org.apache.paimon.options.Options
+import org.apache.paimon.spark.source.SparkV2Write
+import org.apache.paimon.table.FileStoreTable
+
+import org.apache.spark.sql.connector.write.{SupportsDynamicOverwrite, SupportsOverwrite, WriteBuilder}
+import org.apache.spark.sql.sources.{And, Filter}
+
+import scala.collection.JavaConverters._
+
+private class SparkV2WriteBuilder(table: FileStoreTable, options: Options)
+ extends BaseWriteBuilder(table, options)
+ with SupportsOverwrite
+ with SupportsDynamicOverwrite {
+
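+ // overwritePartitions: null means a plain append (no overwrite requested), an empty map means a
+ // truncate-like overwrite of the whole table, and a non-empty map means a static partition overwrite.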
+ private var overwriteDynamic = false
+ private var overwritePartitions: Map[String, String] = null
+ override def build = new SparkV2Write(table, overwriteDynamic, overwritePartitions.asJava)
+ override def overwrite(filters: Array[Filter]): WriteBuilder = {
+ if (overwriteDynamic) {
+ throw new IllegalArgumentException("Cannot overwrite dynamically and by filter both")
+ }
+
+ failIfCanNotOverwrite(filters)
+
+ val conjunctiveFilters = if (filters.nonEmpty) {
+ Some(filters.reduce((l, r) => And(l, r)))
+ } else {
+ None
+ }
+
+ if (isTruncate(conjunctiveFilters.get)) {
+ overwritePartitions = Map.empty[String, String]
+ } else {
+ overwritePartitions =
+ convertPartitionFilterToMap(conjunctiveFilters.get, table.schema.logicalPartitionType())
+ }
+
+ this
+ }
+
+ override def overwriteDynamicPartitions(): WriteBuilder = {
+ if (overwritePartitions != null) {
+ throw new IllegalArgumentException("Cannot overwrite dynamically and by filter both")
+ }
+
+ overwriteDynamic = true
+ this
+ }
+}
diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java
index fff94ce0374d..6006a93181c0 100644
--- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java
+++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java
@@ -26,8 +26,12 @@
import org.apache.paimon.table.FileStoreTable;
import org.apache.paimon.table.FileStoreTableFactory;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.functions;
+import org.apache.spark.sql.internal.SQLConf;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
@@ -42,6 +46,7 @@
import java.util.List;
import java.util.stream.Collectors;
+import static org.apache.spark.sql.functions.expr;
import static org.assertj.core.api.Assertions.assertThat;
/** ITCase for spark writer. */
@@ -81,6 +86,53 @@ public void afterEach() {
spark.sql("DROP TABLE T");
}
+ @Test
+ public void testV2Write() {
+ spark.sql(
+ "CREATE TABLE T (a INT, b INT, c STRING) partitioned by (a) TBLPROPERTIES"
+ + " ('primary-key'='a,b', 'bucket'='4', 'file.format'='parquet')");
+
+ spark.conf().set(SparkSQLProperties.USE_V2_WRITE, true);
+
+ spark.sql("INSERT INTO T VALUES(1, 1, 'a'), (2, 2, 'b')");
+ spark.sql("INSERT INTO T VALUES (1, 1, 'A')");
+
+ List rows = spark.sql("SELECT * FROM T").collectAsList();
+ assertThat(rows.toString()).isEqualTo("[[1,1,A], [2,2,b]]");
+ }
+
+ @Test
+ public void testStaticV2Overwrite() {
+ spark.sql(
+ "CREATE TABLE T (a INT, b INT, c STRING) partitioned by (a) TBLPROPERTIES"
+ + " ('primary-key'='a,b', 'bucket'='4', 'file.format'='parquet')");
+
+ spark.conf().set(SparkSQLProperties.USE_V2_WRITE, true);
+ spark.conf().set(SQLConf.PARTITION_OVERWRITE_MODE(), "STATIC");
+
+ spark.sql("INSERT INTO T VALUES(1, 1, 'a'), (2, 2, 'b')");
+ spark.sql("INSERT OVERWRITE T VALUES (1, 1, 'A')");
+
+ List rows = spark.sql("SELECT * FROM T").collectAsList();
+ assertThat(rows.toString()).isEqualTo("[[1,1,A]]");
+ }
+
+ @Test
+ public void testDynamicV2Overwrite() {
+ spark.sql(
+ "CREATE TABLE T (a INT, b INT, c STRING) partitioned by (a) TBLPROPERTIES"
+ + " ('primary-key'='a,b', 'bucket'='4', 'file.format'='parquet')");
+
+ spark.conf().set(SparkSQLProperties.USE_V2_WRITE, true);
+ spark.conf().set(SQLConf.PARTITION_OVERWRITE_MODE(), "DYNAMIC");
+
+ spark.sql("INSERT INTO T VALUES(1, 1, 'a'), (2, 2, 'b')");
+ spark.sql("INSERT OVERWRITE T VALUES (1, 1, 'A')");
+
+ List rows = spark.sql("SELECT * FROM T").collectAsList();
+ assertThat(rows.toString()).isEqualTo("[[2,2,b], [1,1,A]]");
+ }
+
@Test
public void testWrite() {
spark.sql(
@@ -89,6 +141,55 @@ public void testWrite() {
innerSimpleWrite();
}
+ @Test
+ public void testWrite1() throws NoSuchTableException {
+ spark.sql(
+ "CREATE TABLE T (a bigint, b INT, c STRING) TBLPROPERTIES"
+ + " ('primary-key'='a', 'bucket'='4', 'file.format'='parquet')");
+ Dataset<Row> df =
+ spark.range(10)
+ .withColumnRenamed("id", "a")
+ .withColumn("b", functions.expr("CAST(a as INT)"))
+ .withColumn("c", functions.expr("CAST(a as STRING)"))
+ .coalesce(1);
+
+ df.writeTo("T").append();
+ }
+
+ @Test
+ public void testWrite2() throws NoSuchTableException {
+
+ spark.sql(
+ String.format(
+ "CREATE TABLE %s ("
+ + "intCol INT NOT NULL, "
+ + "longCol BIGINT NOT NULL, "
+ + "floatCol FLOAT NOT NULL, "
+ + "doubleCol DOUBLE NOT NULL, "
+ + "decimalCol DECIMAL(20, 5) NOT NULL, "
+ + "stringCol1 STRING NOT NULL, "
+ + "stringCol2 STRING NOT NULL, "
+ + "stringCol3 STRING NOT NULL"
+ + ") using paimon "
+ + "TBLPROPERTIES('primary-key'='stringCol1,stringCol2', 'bucket'='4')",
+ "T"));
+
+ benchmarkData().writeTo("T").append();
+ }
+
+ private Dataset<Row> benchmarkData() {
+ return spark.range(1)
+ .withColumnRenamed("id", "longCol")
+ .withColumn("intCol", expr("CAST(longCol AS INT)"))
+ .withColumn("floatCol", expr("CAST(longCol AS FLOAT)"))
+ .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)"))
+ .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))"))
+ .withColumn("stringCol1", expr("CAST(longCol AS STRING)"))
+ .withColumn("stringCol2", expr("CAST(longCol AS STRING)"))
+ .withColumn("stringCol3", expr("CAST(longCol AS STRING)"))
+ .coalesce(1);
+ }
+
@Test
public void testWritePartitionTable() {
spark.sql(
diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/catalog/functions/BucketFunctionTest.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/catalog/functions/BucketFunctionTest.java
new file mode 100644
index 000000000000..0ad77337a44f
--- /dev/null
+++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/catalog/functions/BucketFunctionTest.java
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.catalog.functions;
+
+import org.apache.paimon.CoreOptions;
+import org.apache.paimon.data.BinaryString;
+import org.apache.paimon.data.Decimal;
+import org.apache.paimon.data.GenericRow;
+import org.apache.paimon.data.InternalRow;
+import org.apache.paimon.data.Timestamp;
+import org.apache.paimon.schema.TableSchema;
+import org.apache.paimon.spark.SparkTypeUtils;
+import org.apache.paimon.table.sink.FixedBucketRowKeyExtractor;
+import org.apache.paimon.types.BigIntType;
+import org.apache.paimon.types.BooleanType;
+import org.apache.paimon.types.DataField;
+import org.apache.paimon.types.DecimalType;
+import org.apache.paimon.types.DoubleType;
+import org.apache.paimon.types.FloatType;
+import org.apache.paimon.types.IntType;
+import org.apache.paimon.types.LocalZonedTimestampType;
+import org.apache.paimon.types.RowType;
+import org.apache.paimon.types.SmallIntType;
+import org.apache.paimon.types.TimestampType;
+import org.apache.paimon.types.TinyIntType;
+import org.apache.paimon.types.VarBinaryType;
+import org.apache.paimon.types.VarCharType;
+
+import org.apache.paimon.shade.guava30.com.google.common.base.Joiner;
+import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableMap;
+
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.types.UTF8String;
+import org.junit.jupiter.api.Test;
+
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.UUID;
+import java.util.concurrent.ThreadLocalRandom;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for Spark bucket functions. */
+public class BucketFunctionTest {
+ private static final int NUM_BUCKETS =
+ ThreadLocalRandom.current().nextInt(1, Integer.MAX_VALUE);
+
+ private static final String BOOLEAN_COL = "boolean_col";
+ private static final String BYTE_COL = "byte_col";
+ private static final String SHORT_COL = "short_col";
+ private static final String INTEGER_COL = "integer_col";
+ private static final String LONG_COL = "long_col";
+ private static final String FLOAT_COL = "float_col";
+ private static final String DOUBLE_COL = "double_col";
+ private static final String STRING_COL = "string_col";
+ private static final String DECIMAL_COL = "decimal_col";
+ private static final String COMPACTED_DECIMAL_COL = "compacted_decimal_col";
+ private static final String TIMESTAMP_COL = "timestamp_col";
+ private static final String LZ_TIMESTAMP_COL = "lz_timestamp_col";
+ private static final String BINARY_COL = "binary_col";
+
+ private static final int DECIMAL_PRECISION = 38;
+ private static final int DECIMAL_SCALE = 18;
+ private static final int COMPACTED_DECIMAL_PRECISION = 18;
+ private static final int COMPACTED_DECIMAL_SCALE = 9;
+ private static final int TIMESTAMP_PRECISION = 6;
+
+ private static final RowType ROW_TYPE =
+ new RowType(
+ Arrays.asList(
+ new DataField(0, BOOLEAN_COL, new BooleanType()),
+ new DataField(1, BYTE_COL, new TinyIntType()),
+ new DataField(2, SHORT_COL, new SmallIntType()),
+ new DataField(3, INTEGER_COL, new IntType()),
+ new DataField(4, LONG_COL, new BigIntType()),
+ new DataField(5, FLOAT_COL, new FloatType()),
+ new DataField(6, DOUBLE_COL, new DoubleType()),
+ new DataField(7, STRING_COL, new VarCharType(VarCharType.MAX_LENGTH)),
+ new DataField(
+ 8,
+ DECIMAL_COL,
+ new DecimalType(DECIMAL_PRECISION, DECIMAL_SCALE)),
+ new DataField(
+ 9,
+ COMPACTED_DECIMAL_COL,
+ new DecimalType(
+ COMPACTED_DECIMAL_PRECISION, COMPACTED_DECIMAL_SCALE)),
+ new DataField(
+ 10, TIMESTAMP_COL, new TimestampType(TIMESTAMP_PRECISION)),
+ new DataField(
+ 11,
+ LZ_TIMESTAMP_COL,
+ new LocalZonedTimestampType(TIMESTAMP_PRECISION)),
+ new DataField(
+ 12, BINARY_COL, new VarBinaryType(VarBinaryType.MAX_LENGTH))));
+
+ private static final StructType SPARK_TYPE = SparkTypeUtils.fromPaimonRowType(ROW_TYPE);
+
+ private static InternalRow randomPaimonInternalRow() {
+ Random random = new Random();
+ BigInteger unscaled = new BigInteger(String.valueOf(random.nextInt()));
+ BigDecimal bigDecimal1 = new BigDecimal(unscaled, DECIMAL_SCALE);
+ BigDecimal bigDecimal2 = new BigDecimal(unscaled, COMPACTED_DECIMAL_SCALE);
+
+ return GenericRow.of(
+ random.nextBoolean(),
+ (byte) random.nextInt(),
+ (short) random.nextInt(),
+ random.nextInt(),
+ random.nextLong(),
+ random.nextFloat(),
+ random.nextDouble(),
+ BinaryString.fromString(UUID.randomUUID().toString()),
+ Decimal.fromBigDecimal(bigDecimal1, DECIMAL_PRECISION, DECIMAL_SCALE),
+ Decimal.fromBigDecimal(
+ bigDecimal2, COMPACTED_DECIMAL_PRECISION, COMPACTED_DECIMAL_SCALE),
+ Timestamp.now(),
+ Timestamp.now(),
+ UUID.randomUUID().toString().getBytes());
+ }
+
+ private static final InternalRow NULL_PAIMON_ROW =
+ GenericRow.of(
+ null, null, null, null, null, null, null, null, null, null, null, null, null);
+
+ @Test
+ public void testBooleanType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+ String[] bucketColumns = {BOOLEAN_COL};
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ @Test
+ public void testByteType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+ String[] bucketColumns = {BYTE_COL};
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ @Test
+ public void testShortType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+ String[] bucketColumns = {SHORT_COL};
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ @Test
+ public void testIntegerType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+ String[] bucketColumns = {INTEGER_COL};
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ @Test
+ public void testLongType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+ String[] bucketColumns = {LONG_COL};
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ @Test
+ public void testFloatType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+ String[] bucketColumns = {FLOAT_COL};
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ @Test
+ public void testDoubleType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+ String[] bucketColumns = {DOUBLE_COL};
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ @Test
+ public void testStringType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+ String[] bucketColumns = {STRING_COL};
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ @Test
+ public void testBinaryType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+ String[] bucketColumns = {BINARY_COL};
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ @Test
+ public void testTimestampType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+ String[] bucketColumns = {TIMESTAMP_COL};
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ @Test
+ public void testLZTimestampType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+ String[] bucketColumns = {LZ_TIMESTAMP_COL};
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ @Test
+ public void testDecimalType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+ String[] bucketColumns = {DECIMAL_COL};
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ @Test
+ public void testCompactedDecimalType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+ String[] bucketColumns = {COMPACTED_DECIMAL_COL};
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ @Test
+ public void testGenericType() {
+ InternalRow internalRow = randomPaimonInternalRow();
+
+ List<String> allColumns = ROW_TYPE.getFieldNames();
+ // the order of columns matters
+ Collections.shuffle(allColumns);
+ String[] bucketColumns = null;
+ while (bucketColumns == null || bucketColumns.length < 2) {
+ bucketColumns =
+ allColumns.stream()
+ .filter(e -> ThreadLocalRandom.current().nextBoolean())
+ .toArray(String[]::new);
+ }
+
+ assertBucketEquals(internalRow, bucketColumns);
+ assertBucketEquals(NULL_PAIMON_ROW, bucketColumns);
+ }
+
+ private static void assertBucketEquals(InternalRow paimonRow, String... bucketColumns) {
+ assertThat(bucketOfSparkFunction(paimonRow, bucketColumns))
+ .as(
+ String.format(
+ "The bucket computed by spark function should be identical to the bucket computed by paimon bucket extractor function, row: %s",
+ paimonRow))
+ .isEqualTo(bucketOfPaimonFunction(paimonRow, bucketColumns));
+ }
+
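+ /** Converts the values of the given columns from the Paimon row into their Spark representations. */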
+ private static Object[] columnSparkValues(InternalRow paimonRow, String... columns) {
+ Object[] values = new Object[columns.length];
+
+ for (int i = 0; i < columns.length; i += 1) {
+ String column = columns[i];
+ org.apache.paimon.types.DataType paimonType = ROW_TYPE.getField(column).type();
+ int fieldIndex = ROW_TYPE.getFieldIndex(column);
+
+ if (paimonRow.isNullAt(fieldIndex)) {
+ values[i] = null;
+ continue;
+ }
+
+ if (paimonType instanceof BooleanType) {
+ values[i] = paimonRow.getBoolean(fieldIndex);
+ } else if (paimonType instanceof TinyIntType) {
+ values[i] = paimonRow.getByte(fieldIndex);
+ } else if (paimonType instanceof SmallIntType) {
+ values[i] = paimonRow.getShort(fieldIndex);
+ } else if (paimonType instanceof IntType) {
+ values[i] = paimonRow.getInt(fieldIndex);
+ } else if (paimonType instanceof BigIntType) {
+ values[i] = paimonRow.getLong(fieldIndex);
+ } else if (paimonType instanceof FloatType) {
+ values[i] = paimonRow.getFloat(fieldIndex);
+ } else if (paimonType instanceof DoubleType) {
+ values[i] = paimonRow.getDouble(fieldIndex);
+ } else if (paimonType instanceof VarCharType) {
+ values[i] = UTF8String.fromBytes(paimonRow.getString(fieldIndex).toBytes());
+ } else if (paimonType instanceof TimestampType
+ || paimonType instanceof LocalZonedTimestampType) {
+ values[i] = paimonRow.getTimestamp(fieldIndex, 9).toMicros();
+ } else if (paimonType instanceof DecimalType) {
+ int precision = ((DecimalType) paimonType).getPrecision();
+ int scale = ((DecimalType) paimonType).getScale();
+ Decimal paimonDecimal = paimonRow.getDecimal(fieldIndex, precision, scale);
+ values[i] =
+ org.apache.spark.sql.types.Decimal.apply(
+ paimonDecimal.toBigDecimal(), precision, scale);
+ } else if (paimonType instanceof VarBinaryType) {
+ values[i] = paimonRow.getBinary(fieldIndex);
+ } else {
+ throw new UnsupportedOperationException(
+ "Unsupported type: " + paimonType.asSQLString());
+ }
+ }
+
+ return values;
+ }
+
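+ /** Prepends the bucket number to the Spark column values to match the bound function's input row. */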
+ private static Object[] invokeParameters(Object[] sparkColumnValues) {
+ Object[] ret = new Object[sparkColumnValues.length + 1];
+ ret[0] = NUM_BUCKETS;
+ System.arraycopy(sparkColumnValues, 0, ret, 1, sparkColumnValues.length);
+ return ret;
+ }
+
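+ /** Builds the Spark input schema: an int num_buckets field followed by the bucket column fields. */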
+ private static StructType bucketColumnsSparkType(String... columns) {
+ StructField[] inputFields = new StructField[columns.length + 1];
+ inputFields[0] =
+ new StructField("num_buckets", DataTypes.IntegerType, false, Metadata.empty());
+
+ for (int i = 0; i < columns.length; i += 1) {
+ String column = columns[i];
+ inputFields[i + 1] = SPARK_TYPE.apply(column);
+ }
+
+ return new StructType(inputFields);
+ }
+
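+ /** Computes the bucket via the Spark bucket function, using the generic produceResult path or the typed invoke method. */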
+ private static int bucketOfSparkFunction(InternalRow paimonRow, String... bucketColumns) {
+ BucketFunction unbound = new BucketFunction();
+ StructType inputSqlType = bucketColumnsSparkType(bucketColumns);
+ BoundFunction boundFunction = unbound.bind(inputSqlType);
+
+ Object[] inputValues = invokeParameters(columnSparkValues(paimonRow, bucketColumns));
+
+ // null values can only be handled by #produceResult
+ if (boundFunction.getClass() == BucketFunction.BucketGeneric.class
+ || Arrays.stream(inputValues).anyMatch(v -> v == null)) {
+ return ((BucketFunction.BucketGeneric) boundFunction)
+ .produceResult(new GenericInternalRow(inputValues));
+ } else {
+ Class<?>[] parameterTypes = new Class<?>[inputSqlType.fields().length];
+ StructField[] inputFields = inputSqlType.fields();
+ for (int i = 0; i < inputSqlType.fields().length; i += 1) {
+ DataType columnType = inputFields[i].dataType();
+ if (columnType == DataTypes.BooleanType) {
+ parameterTypes[i] = boolean.class;
+ } else if (columnType == DataTypes.ByteType) {
+ parameterTypes[i] = byte.class;
+ } else if (columnType == DataTypes.ShortType) {
+ parameterTypes[i] = short.class;
+ } else if (columnType == DataTypes.IntegerType) {
+ parameterTypes[i] = int.class;
+ } else if (columnType == DataTypes.LongType
+ || columnType == DataTypes.TimestampType
+ || columnType == DataTypes.TimestampNTZType) {
+ parameterTypes[i] = long.class;
+ } else if (columnType == DataTypes.FloatType) {
+ parameterTypes[i] = float.class;
+ } else if (columnType == DataTypes.DoubleType) {
+ parameterTypes[i] = double.class;
+ } else if (columnType == DataTypes.StringType) {
+ parameterTypes[i] = UTF8String.class;
+ } else if (columnType instanceof org.apache.spark.sql.types.DecimalType) {
+ parameterTypes[i] = org.apache.spark.sql.types.Decimal.class;
+ } else if (columnType == DataTypes.BinaryType) {
+ parameterTypes[i] = byte[].class;
+ } else {
+ throw new UnsupportedOperationException(
+ "Unsupported type: " + columnType.sql());
+ }
+ }
+
+ try {
+ Method invoke = boundFunction.getClass().getMethod("invoke", parameterTypes);
+ return (int) invoke.invoke(boundFunction, inputValues);
+ } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+
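+ /** Computes the expected bucket with Paimon's {@link FixedBucketRowKeyExtractor}. */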
+ private static int bucketOfPaimonFunction(InternalRow internalRow, String... bucketColumns) {
+ List<DataField> fields = ROW_TYPE.getFields();
+ TableSchema schema =
+ new TableSchema(
+ 0,
+ fields,
+ RowType.currentHighestFieldId(fields),
+ Collections.emptyList(),
+ Collections.emptyList(),
+ ImmutableMap.of(
+ CoreOptions.BUCKET.key(),
+ String.valueOf(NUM_BUCKETS),
+ CoreOptions.BUCKET_KEY.key(),
+ Joiner.on(",").join(bucketColumns)),
+ "");
+
+ FixedBucketRowKeyExtractor fixedBucketRowKeyExtractor =
+ new FixedBucketRowKeyExtractor(schema);
+ fixedBucketRowKeyExtractor.setRecord(internalRow);
+ return fixedBucketRowKeyExtractor.bucket();
+ }
+}