diff --git a/.gitignore b/.gitignore index 8e7d93ebaccc..f164da86f55d 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ target *.iws .java-version dependency-reduced-pom.xml +benchmark/* ### VS Code ### .vscode/ diff --git a/paimon-benchmark/paimon-spark-benchmark/pom.xml b/paimon-benchmark/paimon-spark-benchmark/pom.xml new file mode 100644 index 000000000000..2d27438512aa --- /dev/null +++ b/paimon-benchmark/paimon-spark-benchmark/pom.xml @@ -0,0 +1,162 @@ + + + + 4.0.0 + + paimon-parent + org.apache.paimon + 1.1-SNAPSHOT + + + paimon-spark-benchmark + jar + Apache Paimon Benchmarks + + + 3.5.4 + paimon-spark-benchmark + 1.37 + + + + + org.apache.paimon + paimon-spark-3.5 + ${project.version} + + + + + + + + + + + + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + + + + + org.openjdk.jmh + jmh-core + ${jmh.version} + + + + org.openjdk.jmh + jmh-generator-annprocess + ${jmh.version} + provided + + + + org.apache.logging.log4j + log4j-api + ${log4j.version} + + + + org.apache.logging.log4j + log4j-core + ${log4j.version} + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + org.apache.maven.plugins + maven-shade-plugin + 2.2 + + + package + + shade + + + ${uberjar.name} + + + org.openjdk.jmh.Main + + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + + analyze-only + + + true + + + + + + + \ No newline at end of file diff --git a/paimon-benchmark/paimon-spark-benchmark/src/main/java/org.apache.paimon.spark.source/PaimonSourceWriteBenchmark.java b/paimon-benchmark/paimon-spark-benchmark/src/main/java/org.apache.paimon.spark.source/PaimonSourceWriteBenchmark.java new file mode 100644 index 000000000000..079192d84b63 --- /dev/null +++ b/paimon-benchmark/paimon-spark-benchmark/src/main/java/org.apache.paimon.spark.source/PaimonSourceWriteBenchmark.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.source; + +import org.apache.paimon.spark.SparkCatalog; +import org.apache.paimon.spark.SparkSQLProperties; + +import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableMap; +import org.apache.paimon.shade.guava30.com.google.common.collect.Maps; + +import org.apache.hadoop.fs.Path; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.internal.SQLConf; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; + +import java.io.File; +import java.util.Map; +import java.util.UUID; + +import static org.apache.spark.sql.functions.expr; + +/** + * Paimon source spark write benchmark. + * + *
<p>Usage:
+ *
+ * <p>mvn clean install -pl ':paimon-spark-benchmark' -DskipTests=true
+ *
+ * <p>
java -jar ./paimon-benchmark/paimon-spark-benchmark/target/paimon-spark-benchmark.jar + * org.apache.paimon.spark.source.PaimonSourceWriteBenchmark -o + * benchmark/paimon-source-write-result.txt + */ +@Fork(1) +@State(Scope.Benchmark) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.SingleShotTime) +public class PaimonSourceWriteBenchmark { + public static final String PAIMON_TABLE_NAME = "paimon_table"; + private static final int NUM_ROWS = 3_000_000; + + private final String warehousePath = warehousePath(); + + private SparkSession spark; + + protected void setup() { + spark = + SparkSession.builder() + .config("spark.ui.enabled", false) + .master("local") + .config("spark.sql.catalog.paimon", SparkCatalog.class.getName()) + .config("spark.sql.catalog.paimon.warehouse", warehousePath) + .getOrCreate(); + + spark.sql("CREATE DATABASE paimon.db"); + spark.sql("USE paimon.db"); + + spark.sql( + String.format( + "CREATE TABLE %s (" + + "intCol INT NOT NULL, " + + "longCol BIGINT NOT NULL, " + + "floatCol FLOAT NOT NULL, " + + "doubleCol DOUBLE NOT NULL, " + + "decimalCol DECIMAL(20, 5) NOT NULL, " + + "stringCol1 STRING NOT NULL, " + + "stringCol2 STRING NOT NULL, " + + "stringCol3 STRING NOT NULL" + + ") using paimon " + + "TBLPROPERTIES('primary-key'='intCol,stringCol2', 'bucket'='2')", + PAIMON_TABLE_NAME)); + } + + protected String warehousePath() { + Path warehosuePath = new Path("/tmp", "paimon-warehouse-" + UUID.randomUUID()); + return warehosuePath.toString(); + } + + @Setup + public void setupBenchmark() { + setup(); + } + + @TearDown + public void tearDown() { + File warehouseDir = new File(warehousePath); + deleteFile(warehouseDir); + } + + public static boolean deleteFile(File file) { + if (file.isDirectory()) { + File[] files = file.listFiles(); + if (files != null) { + for (File child : files) { + deleteFile(child); + } + } + } + return file.delete(); + } + + @Benchmark + @Threads(1) + public void v1Write() throws NoSuchTableException { + benchmarkData().writeTo(PAIMON_TABLE_NAME).append(); + } + + @Benchmark + @Threads(1) + public void v2Write() { + Map conf = ImmutableMap.of(SparkSQLProperties.USE_V2_WRITE, "true"); + withSQLConf( + conf, + () -> { + try { + benchmarkData().writeTo(PAIMON_TABLE_NAME).append(); + } catch (NoSuchTableException e) { + throw new RuntimeException(e); + } + }); + } + + private Dataset benchmarkData() { + return spark.range(NUM_ROWS) + .withColumn("intCol", expr("CAST(id AS INT)")) + .withColumn("longCol", expr("CAST(id AS LONG)")) + .withColumn("floatCol", expr("CAST(id AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(id AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(id AS DECIMAL(20, 5))")) + .withColumn("stringCol1", expr("CAST(id AS STRING)")) + .withColumn("stringCol2", expr("CAST(id AS STRING)")) + .withColumn("stringCol3", expr("CAST(id AS STRING)")) + .drop("id") + .coalesce(1); + } + + private void withSQLConf(Map sparkSqlConf, Action action) { + SQLConf sqlConf = SQLConf.get(); + + Map currentConfValues = Maps.newHashMap(); + sparkSqlConf + .keySet() + .forEach( + confKey -> { + if (sqlConf.contains(confKey)) { + String currentConfValue = sqlConf.getConfString(confKey); + currentConfValues.put(confKey, currentConfValue); + } + }); + + sparkSqlConf.forEach( + (confKey, confValue) -> { + if (SQLConf.isStaticConfigKey(confKey)) { + throw new RuntimeException( + "Cannot modify the value of a static config: " + confKey); + } + sqlConf.setConfString(confKey, confValue); + }); + try { + action.invoke(); + } 
finally { + sparkSqlConf.forEach( + (confKey, confValue) -> { + if (currentConfValues.containsKey(confKey)) { + sqlConf.setConfString(confKey, currentConfValues.get(confKey)); + } else { + sqlConf.unsetConf(confKey); + } + }); + } + } + + /** Action functional interface. */ + @FunctionalInterface + public interface Action { + void invoke(); + } +} diff --git a/paimon-benchmark/pom.xml b/paimon-benchmark/pom.xml index 75e9a1d51ba4..e8093ac79914 100644 --- a/paimon-benchmark/pom.xml +++ b/paimon-benchmark/pom.xml @@ -36,6 +36,7 @@ under the License. paimon-cluster-benchmark paimon-micro-benchmarks + paimon-spark-benchmark diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkConversions.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkConversions.java new file mode 100644 index 000000000000..c4f0e475a358 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkConversions.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark; + +import org.apache.paimon.utils.Preconditions; + +import org.apache.spark.sql.types.Decimal; + +import java.util.function.Function; + +/** Data conversions between Spark and Paimon. */ +public class SparkConversions { + private SparkConversions() {} + + public static Function decimalSparkToPaimon( + int precision, int scale) { + Preconditions.checkArgument( + precision <= 38, + "Decimals with precision larger than 38 are not supported: %s", + precision); + + if (org.apache.paimon.data.Decimal.isCompact(precision)) { + return sparkDecimal -> + org.apache.paimon.data.Decimal.fromUnscaledLong( + sparkDecimal.toUnscaledLong(), precision, scale); + } else { + return sparkDecimal -> + org.apache.paimon.data.Decimal.fromBigDecimal( + sparkDecimal.toJavaBigDecimal(), precision, scale); + } + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java new file mode 100644 index 000000000000..deae55b4f4d2 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark; + +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.Decimal; +import org.apache.paimon.data.InternalArray; +import org.apache.paimon.data.InternalMap; +import org.apache.paimon.data.Timestamp; +import org.apache.paimon.data.variant.Variant; +import org.apache.paimon.types.ArrayType; +import org.apache.paimon.types.DataType; +import org.apache.paimon.types.MapType; +import org.apache.paimon.types.RowKind; +import org.apache.paimon.types.RowType; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; + +import java.io.Serializable; + +/** Wrapper of Spark {@link InternalRow}. */ +public class SparkInternalRowWrapper implements org.apache.paimon.data.InternalRow, Serializable { + private final RowType paimonRowType; + private final InternalRow sparkInternalRow; + private final RowKind rowKind; + + public SparkInternalRowWrapper(RowType paimonType, InternalRow sparkInternalRow) { + this(paimonType, RowKind.INSERT, sparkInternalRow); + } + + public SparkInternalRowWrapper( + RowType paimonType, RowKind rowKind, InternalRow sparkInternalRow) { + this.paimonRowType = paimonType; + this.rowKind = rowKind; + this.sparkInternalRow = sparkInternalRow; + } + + @Override + public int getFieldCount() { + return sparkInternalRow.numFields(); + } + + @Override + public RowKind getRowKind() { + return rowKind; + } + + @Override + public void setRowKind(RowKind rowKind) { + throw new UnsupportedOperationException( + "SparkInternalRowWrapper does not support modifying row kind"); + } + + @Override + public boolean isNullAt(int i) { + return sparkInternalRow.isNullAt(i); + } + + @Override + public boolean getBoolean(int i) { + return sparkInternalRow.getBoolean(i); + } + + @Override + public byte getByte(int i) { + return sparkInternalRow.getByte(i); + } + + @Override + public short getShort(int i) { + return sparkInternalRow.getShort(i); + } + + @Override + public int getInt(int i) { + return sparkInternalRow.getInt(i); + } + + @Override + public long getLong(int i) { + return sparkInternalRow.getLong(i); + } + + @Override + public float getFloat(int i) { + return sparkInternalRow.getFloat(i); + } + + @Override + public double getDouble(int i) { + return sparkInternalRow.getDouble(i); + } + + @Override + public BinaryString getString(int i) { + return sparkInternalRow.isNullAt(i) + ? null + : BinaryString.fromBytes(sparkInternalRow.getUTF8String(i).getBytes()); + } + + @Override + public Decimal getDecimal(int i, int precision, int scale) { + return sparkInternalRow.isNullAt(i) + ? null + : Decimal.fromBigDecimal( + sparkInternalRow.getDecimal(i, precision, scale).toJavaBigDecimal(), + precision, + scale); + } + + @Override + public Timestamp getTimestamp(int i, int precision) { + return sparkInternalRow.isNullAt(i) + ? 
null + : Timestamp.fromMicros(sparkInternalRow.getLong(i)); + } + + @Override + public byte[] getBinary(int i) { + return sparkInternalRow.getBinary(i); + } + + @Override + public Variant getVariant(int pos) { + // return SparkShimLoader.getSparkShim().toPaimonVariant(row.getAs(i)); + // TODO + throw new UnsupportedOperationException(); + } + + @Override + public InternalArray getArray(int i) { + return sparkInternalRow.isNullAt(i) + ? null + : new SparkArrayDataWrapper( + ((ArrayType) paimonRowType.getTypeAt(i)).getElementType(), + sparkInternalRow.getArray(i)); + } + + @Override + public InternalMap getMap(int i) { + return sparkInternalRow.isNullAt(i) + ? null + : toPaimonInternalMap( + (MapType) paimonRowType.getTypeAt(i), sparkInternalRow.getMap(i)); + } + + @Override + public org.apache.paimon.data.InternalRow getRow(int i, int numFields) { + return sparkInternalRow.isNullAt(i) + ? null + : new SparkInternalRowWrapper( + (RowType) paimonRowType.getTypeAt(i), + sparkInternalRow.getStruct(i, numFields)); + } + + private static InternalMap toPaimonInternalMap(MapType mapType, MapData sparkMapData) { + SparkArrayDataWrapper keyArrayWrapper = + new SparkArrayDataWrapper(mapType.getKeyType(), sparkMapData.keyArray()); + SparkArrayDataWrapper valueArrayWrapper = + new SparkArrayDataWrapper(mapType.getValueType(), sparkMapData.valueArray()); + return new InternalMap() { + @Override + public int size() { + return sparkMapData.numElements(); + } + + @Override + public InternalArray keyArray() { + return keyArrayWrapper; + } + + @Override + public InternalArray valueArray() { + return valueArrayWrapper; + } + }; + } + + private static class SparkArrayDataWrapper implements InternalArray { + + private final DataType elementPaimonType; + private final ArrayData sparkArray; + + private SparkArrayDataWrapper(DataType elementPaimonType, ArrayData sparkArray) { + this.sparkArray = sparkArray; + this.elementPaimonType = elementPaimonType; + } + + @Override + public int size() { + return sparkArray.numElements(); + } + + @Override + public boolean isNullAt(int i) { + return sparkArray.isNullAt(i); + } + + @Override + public boolean getBoolean(int i) { + return sparkArray.getBoolean(i); + } + + @Override + public byte getByte(int i) { + return sparkArray.getByte(i); + } + + @Override + public short getShort(int i) { + return sparkArray.getShort(i); + } + + @Override + public int getInt(int i) { + return sparkArray.getInt(i); + } + + @Override + public long getLong(int i) { + return sparkArray.getLong(i); + } + + @Override + public float getFloat(int i) { + return sparkArray.getFloat(i); + } + + @Override + public double getDouble(int i) { + return sparkArray.getDouble(i); + } + + @Override + public BinaryString getString(int i) { + return sparkArray.isNullAt(i) + ? null + : BinaryString.fromBytes(sparkArray.getUTF8String(i).getBytes()); + } + + @Override + public Decimal getDecimal(int i, int precision, int scale) { + return sparkArray.isNullAt(i) + ? null + : Decimal.fromBigDecimal( + sparkArray.getDecimal(i, precision, scale).toJavaBigDecimal(), + precision, + scale); + } + + @Override + public Timestamp getTimestamp(int i, int precision) { + return sparkArray.isNullAt(i) ? 
null : Timestamp.fromMicros(sparkArray.getLong(i)); + } + + @Override + public byte[] getBinary(int i) { + return sparkArray.getBinary(i); + } + + @Override + public Variant getVariant(int pos) { + // TODO + throw new UnsupportedOperationException(); + } + + @Override + public InternalArray getArray(int i) { + return sparkArray.isNullAt(i) + ? null + : new SparkArrayDataWrapper( + ((ArrayType) elementPaimonType).getElementType(), + sparkArray.getArray(i)); + } + + @Override + public InternalMap getMap(int i) { + return sparkArray.isNullAt(i) + ? null + : toPaimonInternalMap((MapType) elementPaimonType, sparkArray.getMap(i)); + } + + @Override + public org.apache.paimon.data.InternalRow getRow(int i, int numFields) { + return sparkArray.isNullAt(i) + ? null + : new SparkInternalRowWrapper( + (RowType) elementPaimonType, sparkArray.getStruct(i, numFields)); + } + + @Override + public boolean[] toBooleanArray() { + return sparkArray.toBooleanArray(); + } + + @Override + public byte[] toByteArray() { + return sparkArray.toByteArray(); + } + + @Override + public short[] toShortArray() { + return sparkArray.toShortArray(); + } + + @Override + public int[] toIntArray() { + return sparkArray.toIntArray(); + } + + @Override + public long[] toLongArray() { + return sparkArray.toLongArray(); + } + + @Override + public float[] toFloatArray() { + return sparkArray.toFloatArray(); + } + + @Override + public double[] toDoubleArray() { + return sparkArray.toDoubleArray(); + } + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteRequirement.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteRequirement.java new file mode 100644 index 000000000000..8d4a39becca7 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteRequirement.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark; + +import org.apache.paimon.table.BucketMode; +import org.apache.paimon.table.BucketSpec; +import org.apache.paimon.table.FileStoreTable; + +import org.apache.paimon.shade.guava30.com.google.common.collect.Lists; + +import org.apache.spark.sql.connector.distributions.ClusteredDistribution; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.SortOrder; + +import java.util.List; + +/** Distribution and ordering requirements of spark write. 
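+ *
+ * <p>A rough sketch of what {@link #of} produces (the table layout below is an assumption made for
+ * illustration, not something defined in this patch): for a table partitioned by {@code dt} with
+ * {@code 'bucket' = '2'} and bucket key {@code id}, the required distribution clusters by
+ * {@code identity(`dt`), bucket(2, `id`)}; an unpartitioned bucket-unaware table falls back to
+ * {@code Distributions.unspecified()}.
+ *
+ * <pre>{@code
+ * SparkWriteRequirement requirement = SparkWriteRequirement.of(table);
+ * Distribution distribution = requirement.distribution(); // consumed through RequiresDistributionAndOrdering
+ * SortOrder[] ordering = requirement.ordering();          // currently always empty
+ * }</pre>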
*/ +public class SparkWriteRequirement { + public static final SparkWriteRequirement EMPTY = + new SparkWriteRequirement(Distributions.unspecified()); + + private static final SortOrder[] EMPTY_ORDERING = new SortOrder[0]; + private final Distribution distribution; + + public static SparkWriteRequirement of(FileStoreTable table) { + BucketSpec bucketSpec = table.bucketSpec(); + BucketMode bucketMode = bucketSpec.getBucketMode(); + switch (bucketMode) { + case HASH_FIXED: + case BUCKET_UNAWARE: + break; + default: + throw new UnsupportedOperationException( + String.format("Unsupported bucket mode %s", bucketMode)); + } + + List clusteringExpressions = Lists.newArrayList(); + + List partitionKeys = table.schema().partitionKeys(); + for (String partitionKey : partitionKeys) { + clusteringExpressions.add(Expressions.identity(quote(partitionKey))); + } + + if (bucketMode == BucketMode.HASH_FIXED) { + String[] quotedBucketKeys = + bucketSpec.getBucketKeys().stream() + .map(SparkWriteRequirement::quote) + .toArray(String[]::new); + clusteringExpressions.add( + Expressions.bucket(bucketSpec.getNumBuckets(), quotedBucketKeys)); + } + + if (clusteringExpressions.isEmpty()) { + return EMPTY; + } + + ClusteredDistribution distribution = + Distributions.clustered( + clusteringExpressions.toArray( + clusteringExpressions.toArray(new Expression[0]))); + + return new SparkWriteRequirement(distribution); + } + + private static String quote(String columnName) { + return String.format("`%s`", columnName.replace("`", "``")); + } + + private SparkWriteRequirement(Distribution distribution) { + this.distribution = distribution; + } + + public Distribution distribution() { + return distribution; + } + + public SortOrder[] ordering() { + return EMPTY_ORDERING; + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/BucketFunction.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/BucketFunction.java new file mode 100644 index 000000000000..34bb769f0ef9 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/BucketFunction.java @@ -0,0 +1,724 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.catalog.functions; + +import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.data.BinaryRowWriter; +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.Timestamp; +import org.apache.paimon.spark.SparkConversions; +import org.apache.paimon.table.BucketMode; +import org.apache.paimon.table.FileStoreTable; +import org.apache.paimon.types.ArrayType; +import org.apache.paimon.types.MapType; +import org.apache.paimon.types.RowType; +import org.apache.paimon.types.VariantType; + +import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableMap; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampNTZType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.unsafe.types.UTF8String; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.lang.reflect.InvocationTargetException; +import java.util.Arrays; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** A Spark function implementation for the Paimon bucket transform. 
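+ *
+ * <p>A minimal usage sketch, binding the function by hand (the input schema below is a made-up
+ * example; in practice Spark binds it while planning a bucket-distributed write):
+ *
+ * <pre>{@code
+ * StructType input =
+ *         new StructType()
+ *                 .add("numBuckets", DataTypes.IntegerType)
+ *                 .add("id", DataTypes.IntegerType);
+ * BucketFunction.BucketInteger bound =
+ *         (BucketFunction.BucketInteger) new BucketFunction().bind(input);
+ * int bucket = bound.invoke(16, 42); // abs(hash % 16), matching Paimon's fixed-bucket assignment
+ * }</pre>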
*/ +public class BucketFunction implements UnboundFunction { + private static final int NUM_BUCKETS_ORDINAL = 0; + private static final int SPARK_TIMESTAMP_PRECISION = 6; + + private static final Map> BUCKET_FUNCTIONS; + + static { + ImmutableMap.Builder> builder = + ImmutableMap.builder(); + builder.put("BucketBoolean", BucketBoolean.class); + builder.put("BucketByte", BucketByte.class); + builder.put("BucketShort", BucketShort.class); + builder.put("BucketInteger", BucketInteger.class); + builder.put("BucketLong", BucketLong.class); + builder.put("BucketFloat", BucketFloat.class); + builder.put("BucketDouble", BucketDouble.class); + builder.put("BucketString", BucketString.class); + builder.put("BucketDecimal", BucketDecimal.class); + builder.put("BucketTimestamp", BucketTimestamp.class); + builder.put("BucketBinary", BucketBinary.class); + + // Joint bucket fields of common types + builder.put("BucketIntegerInteger", BucketIntegerInteger.class); + builder.put("BucketIntegerLong", BucketIntegerLong.class); + builder.put("BucketIntegerString", BucketIntegerString.class); + builder.put("BucketLongInteger", BucketLongInteger.class); + builder.put("BucketLongLong", BucketLongLong.class); + builder.put("BucketLongString", BucketLongString.class); + builder.put("BucketStringInteger", BucketStringInteger.class); + builder.put("BucketStringLong", BucketStringLong.class); + builder.put("BucketStringString", BucketStringString.class); + + BUCKET_FUNCTIONS = builder.build(); + } + + public static boolean supportsTable(FileStoreTable table) { + if (table.bucketMode() != BucketMode.HASH_FIXED) { + return false; + } + + return table.schema().logicalBucketKeyType().getFieldTypes().stream() + .allMatch(BucketFunction::supportsType); + } + + private static boolean supportsType(org.apache.paimon.types.DataType type) { + if (type instanceof ArrayType + || type instanceof MapType + || type instanceof RowType + || type instanceof VariantType) { + return false; + } + + if (type instanceof org.apache.paimon.types.TimestampType) { + return ((org.apache.paimon.types.TimestampType) type).getPrecision() + == SPARK_TIMESTAMP_PRECISION; + } + + if (type instanceof org.apache.paimon.types.LocalZonedTimestampType) { + return ((org.apache.paimon.types.LocalZonedTimestampType) type).getPrecision() + == SPARK_TIMESTAMP_PRECISION; + } + + return true; + } + + @Override + public BoundFunction bind(StructType inputType) { + StructField[] fields = inputType.fields(); + + StringBuilder classNameBuilder = new StringBuilder("Bucket"); + DataType[] bucketKeyTypes = new DataType[fields.length - 1]; + for (int i = 1; i < fields.length; i += 1) { + DataType dataType = fields[i].dataType(); + bucketKeyTypes[i - 1] = dataType; + if (dataType instanceof BooleanType) { + classNameBuilder.append("Boolean"); + } else if (dataType instanceof ByteType) { + classNameBuilder.append("Byte"); + } else if (dataType instanceof ShortType) { + classNameBuilder.append("Short"); + } else if (dataType instanceof IntegerType) { + classNameBuilder.append("Integer"); + } else if (dataType instanceof LongType) { + classNameBuilder.append("Long"); + } else if (dataType instanceof FloatType) { + classNameBuilder.append("Float"); + } else if (dataType instanceof DoubleType) { + classNameBuilder.append("Double"); + } else if (dataType instanceof StringType) { + classNameBuilder.append("String"); + } else if (dataType instanceof DecimalType) { + classNameBuilder.append("Decimal"); + } else if (dataType instanceof TimestampType || dataType instanceof 
TimestampNTZType) { + classNameBuilder.append("Timestamp"); + } else if (dataType instanceof BinaryType) { + classNameBuilder.append("Binary"); + } else { + throw new UnsupportedOperationException( + "Unsupported type: " + dataType.simpleString()); + } + } + + Class bucketClass = + BUCKET_FUNCTIONS.getOrDefault(classNameBuilder.toString(), BucketGeneric.class); + + try { + return bucketClass + .getConstructor(DataType[].class) + .newInstance( + (Object) + bucketKeyTypes /* cast DataType[] to Object as newInstance takes varargs */); + } catch (InstantiationException + | IllegalAccessException + | InvocationTargetException + | NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + + @Override + public String description() { + return name() + "(numBuckets, col1, col2, ...)"; + } + + @Override + public String name() { + return "bucket"; + } + + /** Bound bucket function for generic type. */ + public static class BucketGeneric implements ScalarFunction { + protected final DataType[] bucketKeyTypes; + protected final BinaryRow bucketKeyRow; + // not serializable + protected transient BinaryRowWriter bucketKeyWriter; + private transient ValueWriter[] valueWriters; + + public BucketGeneric(DataType[] sqlTypes) { + this.bucketKeyTypes = sqlTypes; + this.bucketKeyRow = new BinaryRow(bucketKeyTypes.length); + this.bucketKeyWriter = new BinaryRowWriter(bucketKeyRow); + this.valueWriters = createValueWriter(bucketKeyTypes); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + this.bucketKeyWriter = new BinaryRowWriter(bucketKeyRow); + this.valueWriters = createValueWriter(bucketKeyTypes); + } + + private static ValueWriter[] createValueWriter(DataType[] columnTypes) { + ValueWriter[] writers = new ValueWriter[columnTypes.length]; + for (int i = 0; i < columnTypes.length; i += 1) { + writers[i] = ValueWriter.of(columnTypes[i]); + } + + return writers; + } + + @Override + public Integer produceResult(InternalRow input) { + bucketKeyWriter.reset(); + for (int i = 0; i < valueWriters.length; i += 1) { + valueWriters[i].write(bucketKeyWriter, i, input, i + 1); + } + + bucketKeyWriter.complete(); + return bucket(bucketKeyRow, input.getInt(NUM_BUCKETS_ORDINAL)); + } + + @Override + public DataType[] inputTypes() { + DataType[] inputTypes = new DataType[bucketKeyTypes.length + 1]; + inputTypes[0] = DataTypes.IntegerType; + for (int i = 0; i < bucketKeyTypes.length; i += 1) { + inputTypes[i + 1] = bucketKeyTypes[i]; + } + + return inputTypes; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public boolean isResultNullable() { + return false; + } + + @Override + public String name() { + return "bucket"; + } + + @Override + public String canonicalName() { + return String.format( + "paimon.bucket(%s)", + Arrays.stream(bucketKeyTypes) + .map(DataType::catalogString) + .collect(Collectors.joining(", "))); + } + + // org.apache.paimon.table.sink.KeyAndBucketExtractor.bucket(int, int) + protected static int bucket(BinaryRow bucketKey, int numBuckets) { + return Math.abs(bucketKey.hashCode() % numBuckets); + } + } + + /** Bound bucket function for {Boolean} type. 
*/ + public static class BucketBoolean extends BucketGeneric { + + public BucketBoolean(DataType[] sqlTypes) { + super(sqlTypes); + } + + public int invoke(int numBuckets, boolean value0) { + bucketKeyWriter.reset(); + bucketKeyWriter.writeBoolean(0, value0); + bucketKeyWriter.complete(); + return bucket(bucketKeyRow, numBuckets); + } + } + + /** Bound bucket function for {Byte} type. */ + public static class BucketByte extends BucketGeneric { + + public BucketByte(DataType[] sqlTypes) { + super(sqlTypes); + } + + public int invoke(int numBuckets, byte value0) { + bucketKeyWriter.reset(); + bucketKeyWriter.writeByte(0, value0); + bucketKeyWriter.complete(); + return bucket(bucketKeyRow, numBuckets); + } + } + + /** Bound bucket function for {Short} type. */ + public static class BucketShort extends BucketGeneric { + + public BucketShort(DataType[] sqlTypes) { + super(sqlTypes); + } + + public int invoke(int numBuckets, short value0) { + bucketKeyWriter.reset(); + bucketKeyWriter.writeShort(0, value0); + bucketKeyWriter.complete(); + return bucket(bucketKeyRow, numBuckets); + } + } + + /** Bound bucket function for {Integer} type. */ + public static class BucketInteger extends BucketGeneric { + + public BucketInteger(DataType[] sqlTypes) { + super(sqlTypes); + } + + public int invoke(int numBuckets, int value0) { + bucketKeyWriter.reset(); + bucketKeyWriter.writeInt(0, value0); + bucketKeyWriter.complete(); + return bucket(bucketKeyRow, numBuckets); + } + } + + /** Bound bucket function for {Long} type. */ + public static class BucketLong extends BucketGeneric { + + public BucketLong(DataType[] sqlTypes) { + super(sqlTypes); + } + + public int invoke(int numBuckets, long value0) { + bucketKeyWriter.reset(); + bucketKeyWriter.writeLong(0, value0); + bucketKeyWriter.complete(); + return bucket(bucketKeyRow, numBuckets); + } + } + + /** Bound bucket function for {Float} type. */ + public static class BucketFloat extends BucketGeneric { + + public BucketFloat(DataType[] sqlTypes) { + super(sqlTypes); + } + + public int invoke(int numBuckets, float value0) { + bucketKeyWriter.reset(); + bucketKeyWriter.writeFloat(0, value0); + bucketKeyWriter.complete(); + return bucket(bucketKeyRow, numBuckets); + } + } + + /** Bound bucket function for {Double} type. */ + public static class BucketDouble extends BucketGeneric { + + public BucketDouble(DataType[] sqlTypes) { + super(sqlTypes); + } + + public int invoke(int numBuckets, double value0) { + bucketKeyWriter.reset(); + bucketKeyWriter.writeDouble(0, value0); + bucketKeyWriter.complete(); + return bucket(bucketKeyRow, numBuckets); + } + } + + /** Bound bucket function for {String} type. */ + public static class BucketString extends BucketGeneric { + + public BucketString(DataType[] sqlTypes) { + super(sqlTypes); + } + + public int invoke(int numBuckets, UTF8String value0) { + bucketKeyWriter.reset(); + bucketKeyWriter.writeString(0, BinaryString.fromBytes(value0.getBytes())); + bucketKeyWriter.complete(); + return bucket(bucketKeyRow, numBuckets); + } + } + + /** Bound bucket function for {Decimal} type. 
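+ *
+ * <p>Spark's {@code Decimal} is converted with {@code SparkConversions.decimalSparkToPaimon}:
+ * compact precisions (where {@code Decimal.isCompact} holds) go through {@code fromUnscaledLong},
+ * larger ones through {@code fromBigDecimal}, so the bytes hashed here match what Paimon itself
+ * writes for the same value.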
 */
+    public static class BucketDecimal extends BucketGeneric {
+        private final int precision;
+        private final int scale;
+        private final Function<Decimal, org.apache.paimon.data.Decimal> converter;
+
+        public BucketDecimal(DataType[] sqlTypes) {
+            super(sqlTypes);
+            DecimalType decimalType = (DecimalType) sqlTypes[0];
+            this.precision = decimalType.precision();
+            this.scale = decimalType.scale();
+            this.converter = SparkConversions.decimalSparkToPaimon(precision, scale);
+        }
+
+        public int invoke(int numBuckets, Decimal value0) {
+            bucketKeyWriter.reset();
+            bucketKeyWriter.writeDecimal(0, converter.apply(value0), precision);
+            bucketKeyWriter.complete();
+            return bucket(bucketKeyRow, numBuckets);
+        }
+    }
+
+    /** Bound bucket function for paimon {@code TimestampType} type with precision 6. */
+    public static class BucketTimestamp extends BucketGeneric {
+
+        public BucketTimestamp(DataType[] sqlTypes) {
+            super(sqlTypes);
+        }
+
+        public int invoke(int numBuckets, long value0) {
+            bucketKeyWriter.reset();
+            bucketKeyWriter.writeTimestamp(
+                    0, Timestamp.fromMicros(value0), SPARK_TIMESTAMP_PRECISION);
+            bucketKeyWriter.complete();
+            return bucket(bucketKeyRow, numBuckets);
+        }
+    }
+
+    /** Bound bucket function for {Binary} type. */
+    public static class BucketBinary extends BucketGeneric {
+
+        public BucketBinary(DataType[] sqlTypes) {
+            super(sqlTypes);
+        }
+
+        public int invoke(int numBuckets, byte[] value0) {
+            bucketKeyWriter.reset();
+            bucketKeyWriter.writeBinary(0, value0);
+            bucketKeyWriter.complete();
+            return bucket(bucketKeyRow, numBuckets);
+        }
+    }
+
+    /** Bucket function for {Integer, Integer} composite type. */
+    public static class BucketIntegerInteger extends BucketGeneric {
+
+        public BucketIntegerInteger(DataType[] sqlTypes) {
+            super(sqlTypes);
+        }
+
+        public int invoke(int numBuckets, int value0, int value1) {
+            bucketKeyWriter.reset();
+            bucketKeyWriter.writeInt(0, value0);
+            bucketKeyWriter.writeInt(1, value1);
+            bucketKeyWriter.complete();
+            return bucket(bucketKeyRow, numBuckets);
+        }
+    }
+
+    /** Bucket function for {Integer, Long} composite type. */
+    public static class BucketIntegerLong extends BucketGeneric {
+
+        public BucketIntegerLong(DataType[] sqlTypes) {
+            super(sqlTypes);
+        }
+
+        public int invoke(int numBuckets, int value0, long value1) {
+            bucketKeyWriter.reset();
+            bucketKeyWriter.writeInt(0, value0);
+            bucketKeyWriter.writeLong(1, value1);
+            bucketKeyWriter.complete();
+            return bucket(bucketKeyRow, numBuckets);
+        }
+    }
+
+    /** Bucket function for {Integer, String} composite type. */
+    public static class BucketIntegerString extends BucketGeneric {
+
+        public BucketIntegerString(DataType[] sqlTypes) {
+            super(sqlTypes);
+        }
+
+        public int invoke(int numBuckets, int value0, UTF8String value1) {
+            bucketKeyWriter.reset();
+            bucketKeyWriter.writeInt(0, value0);
+            bucketKeyWriter.writeString(1, BinaryString.fromBytes(value1.getBytes()));
+            bucketKeyWriter.complete();
+            return bucket(bucketKeyRow, numBuckets);
+        }
+    }
+
+    /** Bucket function for {Long, Long} composite type. */
+    public static class BucketLongLong extends BucketGeneric {
+
+        public BucketLongLong(DataType[] sqlTypes) {
+            super(sqlTypes);
+        }
+
+        public int invoke(int numBuckets, long value0, long value1) {
+            bucketKeyWriter.reset();
+            bucketKeyWriter.writeLong(0, value0);
+            bucketKeyWriter.writeLong(1, value1);
+            bucketKeyWriter.complete();
+            return bucket(bucketKeyRow, numBuckets);
+        }
+    }
+
+    /** Bucket function for {Long, Integer} composite type.
 */
+    public static class BucketLongInteger extends BucketGeneric {
+
+        public BucketLongInteger(DataType[] sqlTypes) {
+            super(sqlTypes);
+        }
+
+        public int invoke(int numBuckets, long value0, int value1) {
+            bucketKeyWriter.reset();
+            bucketKeyWriter.writeLong(0, value0);
+            bucketKeyWriter.writeInt(1, value1);
+            bucketKeyWriter.complete();
+            return bucket(bucketKeyRow, numBuckets);
+        }
+    }
+
+    /** Bucket function for {Long, String} composite type. */
+    public static class BucketLongString extends BucketGeneric {
+
+        public BucketLongString(DataType[] sqlTypes) {
+            super(sqlTypes);
+        }
+
+        public int invoke(int numBuckets, long value0, UTF8String value1) {
+            bucketKeyWriter.reset();
+            bucketKeyWriter.writeLong(0, value0);
+            bucketKeyWriter.writeString(1, BinaryString.fromBytes(value1.getBytes()));
+            bucketKeyWriter.complete();
+            return bucket(bucketKeyRow, numBuckets);
+        }
+    }
+
+    /** Bucket function for {String, Integer} composite type. */
+    public static class BucketStringInteger extends BucketGeneric {
+
+        public BucketStringInteger(DataType[] sqlTypes) {
+            super(sqlTypes);
+        }
+
+        public int invoke(int numBuckets, UTF8String value0, int value1) {
+            bucketKeyWriter.reset();
+            bucketKeyWriter.writeString(0, BinaryString.fromBytes(value0.getBytes()));
+            bucketKeyWriter.writeInt(1, value1);
+            bucketKeyWriter.complete();
+            return bucket(bucketKeyRow, numBuckets);
+        }
+    }
+
+    /** Bucket function for {String, Long} composite type. */
+    public static class BucketStringLong extends BucketGeneric {
+
+        public BucketStringLong(DataType[] sqlTypes) {
+            super(sqlTypes);
+        }
+
+        public int invoke(int numBuckets, UTF8String value0, long value1) {
+            bucketKeyWriter.reset();
+            bucketKeyWriter.writeString(0, BinaryString.fromBytes(value0.getBytes()));
+            bucketKeyWriter.writeLong(1, value1);
+            bucketKeyWriter.complete();
+            return bucket(bucketKeyRow, numBuckets);
+        }
+    }
+
+    /** Bucket function for {String, String} composite type.
*/ + public static class BucketStringString extends BucketGeneric { + + public BucketStringString(DataType[] sqlTypes) { + super(sqlTypes); + } + + public int invoke(int numBuckets, UTF8String value0, UTF8String value1) { + bucketKeyWriter.reset(); + bucketKeyWriter.writeString(0, BinaryString.fromBytes(value0.getBytes())); + bucketKeyWriter.writeString(1, BinaryString.fromBytes(value1.getBytes())); + bucketKeyWriter.complete(); + return bucket(bucketKeyRow, numBuckets); + } + } + + @FunctionalInterface + interface ValueWriter { + void write(BinaryRowWriter writer, int writePos, InternalRow srcRow, int srcPos); + + static ValueWriter of(DataType type) { + if (type instanceof BooleanType) { + return (writer, writePos, srcRow, srcPos) -> { + if (srcRow.isNullAt(srcPos)) { + writer.setNullAt(writePos); + } else { + writer.writeBoolean(writePos, srcRow.getBoolean(srcPos)); + } + }; + + } else if (type instanceof ByteType) { + return (writer, writePos, srcRow, srcPos) -> { + if (srcRow.isNullAt(srcPos)) { + writer.setNullAt(writePos); + } else { + writer.writeByte(writePos, srcRow.getByte(srcPos)); + } + }; + + } else if (type instanceof ShortType) { + return (writer, writePos, srcRow, srcPos) -> { + if (srcRow.isNullAt(srcPos)) { + writer.setNullAt(writePos); + } else { + writer.writeShort(writePos, srcRow.getShort(srcPos)); + } + }; + + } else if (type instanceof IntegerType) { + return (writer, writePos, srcRow, srcPos) -> { + if (srcRow.isNullAt(srcPos)) { + writer.setNullAt(writePos); + } else { + writer.writeInt(writePos, srcRow.getInt(srcPos)); + } + }; + + } else if (type instanceof LongType) { + return (writer, writePos, srcRow, srcPos) -> { + if (srcRow.isNullAt(srcPos)) { + writer.setNullAt(writePos); + } else { + writer.writeLong(writePos, srcRow.getLong(srcPos)); + } + }; + + } else if (type instanceof FloatType) { + return (writer, writePos, srcRow, srcPos) -> { + if (srcRow.isNullAt(srcPos)) { + writer.setNullAt(writePos); + } else { + writer.writeFloat(writePos, srcRow.getFloat(srcPos)); + } + }; + + } else if (type instanceof DoubleType) { + return (writer, writePos, srcRow, srcPos) -> { + if (srcRow.isNullAt(srcPos)) { + writer.setNullAt(writePos); + } else { + writer.writeDouble(writePos, srcRow.getDouble(srcPos)); + } + }; + + } else if (type instanceof StringType) { + return (writer, writePos, srcRow, srcPos) -> { + if (srcRow.isNullAt(srcPos)) { + writer.setNullAt(writePos); + } else { + writer.writeString( + writePos, + BinaryString.fromBytes(srcRow.getUTF8String(srcPos).getBytes())); + } + }; + + } else if (type instanceof BinaryType) { + return (writer, writePos, srcRow, srcPos) -> { + if (srcRow.isNullAt(srcPos)) { + writer.setNullAt(writePos); + } else { + writer.writeBinary(writePos, srcRow.getBinary(srcPos)); + } + }; + + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType) type; + int precision = decimalType.precision(); + int scale = decimalType.scale(); + boolean compact = org.apache.paimon.data.Decimal.isCompact(precision); + Function converter = + SparkConversions.decimalSparkToPaimon(precision, scale); + + return (writer, writePos, srcRow, srcPos) -> { + if (srcRow.isNullAt(srcPos)) { + // org.apache.paimon.codegen.GenerateUtils.binaryWriterWriteNull + if (!compact) { + writer.writeDecimal(writePos, null, precision); + } else { + writer.setNullAt(writePos); + } + } else { + Decimal decimal = srcRow.getDecimal(srcPos, precision, scale); + writer.writeDecimal(writePos, converter.apply(decimal), precision); + } + }; + + } else if 
(type instanceof TimestampType || type instanceof TimestampNTZType) { + return (writer, writePos, srcRow, srcPos) -> { + if (srcRow.isNullAt(srcPos)) { + // org.apache.paimon.codegen.GenerateUtils.binaryWriterWriteNull + // must not be compacted as only Spark default precision 6 should be allowed + writer.writeTimestamp(writePos, null, SPARK_TIMESTAMP_PRECISION); + } else { + writer.writeTimestamp( + writePos, + Timestamp.fromMicros(srcRow.getLong(srcPos)), + SPARK_TIMESTAMP_PRECISION); + } + }; + + } else { + throw new UnsupportedOperationException("Unsupported type: " + type.simpleString()); + } + } + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/PaimonFunctions.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/PaimonFunctions.java index c7949e11948f..b265419b535b 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/PaimonFunctions.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/PaimonFunctions.java @@ -35,7 +35,6 @@ import java.util.Map; import static org.apache.paimon.utils.Preconditions.checkArgument; -import static org.apache.spark.sql.types.DataTypes.IntegerType; import static org.apache.spark.sql.types.DataTypes.StringType; /** Paimon functions. */ @@ -57,60 +56,6 @@ public static UnboundFunction load(String name) { return FUNCTIONS.get(name); } - /** - * For now, we only support report bucket partitioning for table scan. So the case `SELECT - * bucket(10, col)` would fail since we do not implement {@link - * org.apache.spark.sql.connector.catalog.functions.ScalarFunction} - */ - public static class BucketFunction implements UnboundFunction { - @Override - public BoundFunction bind(StructType inputType) { - if (inputType.size() != 2) { - throw new UnsupportedOperationException( - "Wrong number of inputs (expected numBuckets and value)"); - } - - StructField numBucket = inputType.fields()[0]; - StructField bucketField = inputType.fields()[1]; - checkArgument( - numBucket.dataType() == IntegerType, - "bucket number field must be integer type"); - - return new BoundFunction() { - @Override - public DataType[] inputTypes() { - return new DataType[] {IntegerType, bucketField.dataType()}; - } - - @Override - public DataType resultType() { - return IntegerType; - } - - @Override - public String name() { - return "bucket"; - } - - @Override - public String canonicalName() { - // We have to override this method to make it support canonical equivalent - return "paimon.bucket(" + bucketField.dataType().catalogString() + ", int)"; - } - }; - } - - @Override - public String description() { - return name(); - } - - @Override - public String name() { - return "bucket"; - } - } - /** * For partitioned tables, this function returns the maximum value of the first level partition * of the partitioned table, sorted alphabetically. Note, empty partitions will be skipped. For diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/source/SparkV2Write.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/source/SparkV2Write.java new file mode 100644 index 000000000000..82b43b8d6557 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/source/SparkV2Write.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.source; + +import org.apache.paimon.disk.IOManager; +import org.apache.paimon.spark.SparkInternalRowWrapper; +import org.apache.paimon.spark.SparkUtils; +import org.apache.paimon.spark.SparkWriteRequirement; +import org.apache.paimon.table.FileStoreTable; +import org.apache.paimon.table.sink.BatchTableCommit; +import org.apache.paimon.table.sink.BatchTableWrite; +import org.apache.paimon.table.sink.BatchWriteBuilder; +import org.apache.paimon.table.sink.CommitMessage; +import org.apache.paimon.table.sink.CommitMessageSerializer; +import org.apache.paimon.types.RowType; +import org.apache.paimon.utils.Preconditions; + +import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableMap; +import org.apache.paimon.shade.guava30.com.google.common.collect.Lists; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.expressions.SortOrder; +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.DataWriter; +import org.apache.spark.sql.connector.write.DataWriterFactory; +import org.apache.spark.sql.connector.write.PhysicalWriteInfo; +import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering; +import org.apache.spark.sql.connector.write.Write; +import org.apache.spark.sql.connector.write.WriterCommitMessage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; +import java.util.Map; + +import static org.apache.paimon.CoreOptions.DYNAMIC_PARTITION_OVERWRITE; + +/** Spark V2WRite. 
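+ *
+ * <p>Illustrative only (how this write is expected to be reached; the exact SQL-conf plumbing
+ * lives elsewhere in this patch): with {@code SparkSQLProperties.USE_V2_WRITE} enabled, a
+ * DataFrame append is planned through this class and honors {@link #requiredDistribution()}:
+ *
+ * <pre>{@code
+ * spark.conf().set(SparkSQLProperties.USE_V2_WRITE, "true");
+ * df.writeTo("paimon.db.paimon_table").append();
+ * }</pre>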
*/ +public class SparkV2Write implements Write, RequiresDistributionAndOrdering { + private static final Logger LOG = LoggerFactory.getLogger(SparkV2Write.class); + + private final FileStoreTable table; + private final BatchWriteBuilder writeBuilder; + private final SparkWriteRequirement writeRequirement; + private final boolean overwriteDynamic; + private final Map overwritePartitions; + + public SparkV2Write( + FileStoreTable newTable, + boolean overwriteDynamic, + Map overwritePartitions) { + Preconditions.checkArgument( + !overwriteDynamic || overwritePartitions == null, + "Cannot overwrite dynamically and by filter both"); + + this.table = + newTable.copy( + ImmutableMap.of( + DYNAMIC_PARTITION_OVERWRITE.key(), + Boolean.toString(overwriteDynamic))); + this.writeBuilder = table.newBatchWriteBuilder(); + if (overwritePartitions != null) { + writeBuilder.withOverwrite(overwritePartitions); + } + + this.writeRequirement = SparkWriteRequirement.of(table); + this.overwriteDynamic = overwriteDynamic; + this.overwritePartitions = overwritePartitions; + } + + @Override + public Distribution requiredDistribution() { + Distribution distribution = writeRequirement.distribution(); + LOG.info("Requesting {} as write distribution for table {}", distribution, table.name()); + return distribution; + } + + @Override + public SortOrder[] requiredOrdering() { + SortOrder[] ordering = writeRequirement.ordering(); + LOG.info("Requesting {} as write ordering for table {}", ordering, table.name()); + return ordering; + } + + @Override + public BatchWrite toBatch() { + return new PaimonBatchWrite(); + } + + private class PaimonBatchWrite implements BatchWrite { + + @Override + public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { + return new WriterFactory(table.rowType(), writeBuilder); + } + + @Override + public boolean useCommitCoordinator() { + return false; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + LOG.info("Committing to table {}, ", table.name()); + BatchTableCommit batchTableCommit = writeBuilder.newCommit(); + List allCommitMessage = Lists.newArrayList(); + + for (WriterCommitMessage message : messages) { + if (message != null) { + List commitMessages = ((TaskCommit) message).commitMessages(); + allCommitMessage.addAll(commitMessages); + } + } + + try { + long start = System.currentTimeMillis(); + batchTableCommit.commit(allCommitMessage); + LOG.info("Committed in {} ms", System.currentTimeMillis() - start); + batchTableCommit.close(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void abort(WriterCommitMessage[] messages) { + // TODO abort + } + } + + @Override + public String toString() { + return String.format( + "PaimonWrite(table=%s, %s)", + table.name(), + overwriteDynamic + ? 
"overwriteDynamic=true" + : String.format("overwritePartitions=%s", overwritePartitions)); + } + + private static class WriterFactory implements DataWriterFactory { + private final RowType rowType; + private final BatchWriteBuilder batchWriteBuilder; + + WriterFactory(RowType rowType, BatchWriteBuilder batchWriteBuilder) { + this.rowType = rowType; + this.batchWriteBuilder = batchWriteBuilder; + } + + @Override + public DataWriter createWriter(int partitionId, long taskId) { + BatchTableWrite batchTableWrite = batchWriteBuilder.newWrite(); + return new GenericWriter(batchTableWrite, rowType); + } + } + + private static class GenericWriter implements DataWriter { + private final BatchTableWrite batchTableWrite; + private final RowType rowType; + private final IOManager ioManager; + + private GenericWriter(BatchTableWrite batchTableWrite, RowType rowType) { + this.batchTableWrite = batchTableWrite; + this.rowType = rowType; + this.ioManager = SparkUtils.createIOManager(); + batchTableWrite.withIOManager(ioManager); + } + + @Override + public void write(InternalRow record) throws IOException { + // TODO rowKind + SparkInternalRowWrapper wrappedInternalRow = + new SparkInternalRowWrapper(rowType, record); + try { + batchTableWrite.write(wrappedInternalRow); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public WriterCommitMessage commit() throws IOException { + try { + List commitMessages = batchTableWrite.prepareCommit(); + TaskCommit taskCommit = new TaskCommit(commitMessages); + return taskCommit; + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + close(); + } + } + + @Override + public void abort() throws IOException { + // TODO clean uncommitted files + close(); + } + + @Override + public void close() throws IOException { + try { + batchTableWrite.close(); + ioManager.close(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + private static class TaskCommit implements WriterCommitMessage { + private final List serializedMessageList = Lists.newArrayList(); + + TaskCommit(List commitMessages) { + if (commitMessages == null || commitMessages.isEmpty()) { + return; + } + + CommitMessageSerializer serializer = new CommitMessageSerializer(); + for (CommitMessage commitMessage : commitMessages) { + try { + serializedMessageList.add(serializer.serialize(commitMessage)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } + + List commitMessages() { + CommitMessageSerializer deserializer = new CommitMessageSerializer(); + List commitMessageList = Lists.newArrayList(); + for (byte[] serializedMessage : serializedMessageList) { + try { + commitMessageList.add( + deserializer.deserialize(deserializer.getVersion(), serializedMessage)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + return commitMessageList; + } + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/BaseWriteBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/BaseWriteBuilder.scala new file mode 100644 index 000000000000..c6b6e961bfe0 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/BaseWriteBuilder.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark + +import org.apache.paimon.options.Options +import org.apache.paimon.table.FileStoreTable +import org.apache.paimon.types.RowType + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.connector.write.WriteBuilder +import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, Not, Or} + +import scala.collection.JavaConverters._ + +abstract class BaseWriteBuilder(table: FileStoreTable, options: Options) + extends WriteBuilder + with SQLConfHelper { + + private def failWithReason(filter: Filter): Unit = { + throw new RuntimeException( + s"Only support Overwrite filters with Equal and EqualNullSafe, but got: $filter") + } + + private def validateFilter(filter: Filter): Unit = filter match { + case And(left, right) => + validateFilter(left) + validateFilter(right) + case _: Or => failWithReason(filter) + case _: Not => failWithReason(filter) + case e: EqualTo if e.references.length == 1 && !e.value.isInstanceOf[Filter] => + case e: EqualNullSafe if e.references.length == 1 && !e.value.isInstanceOf[Filter] => + case _: AlwaysTrue | _: AlwaysFalse => + case _ => failWithReason(filter) + } + + // `SupportsOverwrite#canOverwrite` is added since Spark 3.4.0. + // We do this checking by self to work with previous Spark version. + protected def failIfCanNotOverwrite(filters: Array[Filter]): Unit = { + // For now, we only support overwrite with two cases: + // - overwrite with partition columns to be compatible with v1 insert overwrite + // See [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveInsertInto#staticDeleteExpression]]. + // - truncate-like overwrite and the filter is always true. 
+ // + // Fast fail for other custom filters that come through the v2 write interface, e.g., + // `dataframe.writeTo(T).overwrite(...)` + val partitionRowType = table.schema.logicalPartitionType() + val partitionNames = partitionRowType.getFieldNames.asScala + val allReferences = filters.flatMap(_.references) + val containsDataColumn = allReferences.exists { + reference => !partitionNames.exists(conf.resolver.apply(reference, _)) + } + if (containsDataColumn) { + throw new RuntimeException( + s"Only support Overwrite filters on partition columns ${partitionNames.mkString( + ", ")}, but got ${filters.mkString(", ")}.") + } + if (allReferences.distinct.length < allReferences.length) { + // fail with `part = 1 and part = 2` + throw new RuntimeException( + s"Only support Overwrite with one filter for each partition column, but got ${filters.mkString(", ")}.") + } + filters.foreach(validateFilter) + } + + private def parseSaveMode( + saveMode: SaveMode, + table: FileStoreTable): (Boolean, Map[String, String]) = { + var dynamicPartitionOverwriteMode = false + val overwritePartition = saveMode match { + case InsertInto => null + case Overwrite(filter) => + if (filter.isEmpty) { + Map.empty[String, String] + } else if (isTruncate(filter.get)) { + Map.empty[String, String] + } else { + convertPartitionFilterToMap(filter.get, table.schema.logicalPartitionType()) + } + case DynamicOverWrite => + dynamicPartitionOverwriteMode = true + Map.empty[String, String] + case _ => + throw new UnsupportedOperationException(s"Unsupported SaveMode: $saveMode.") + } + (dynamicPartitionOverwriteMode, overwritePartition) + } + + /** + * For the 'INSERT OVERWRITE' semantics of SQL, Spark DataSourceV2 will call the `truncate` + * method with an `AlwaysTrue` filter. + */ + def isTruncate(filter: Filter): Boolean = { + val filters = splitConjunctiveFilters(filter) + filters.length == 1 && filters.head.isInstanceOf[AlwaysTrue] + } + + /** See [[org.apache.paimon.spark.SparkWriteBuilder#failIfCanNotOverwrite]] */ + def convertPartitionFilterToMap( + filter: Filter, + partitionRowType: RowType): Map[String, String] = { + // TODO: replace this with SparkV2FilterConverter once Spark 3.2 support is dropped + val converter = new SparkFilterConverter(partitionRowType) + splitConjunctiveFilters(filter).map { + case EqualNullSafe(attribute, value) => + (attribute, converter.convertString(attribute, value)) + case EqualTo(attribute, value) => + (attribute, converter.convertString(attribute, value)) + case _ => + // Should not happen + throw new RuntimeException( + s"Only support Overwrite filters with Equal and EqualNullSafe, but got: $filter") + }.toMap + } + + private def splitConjunctiveFilters(filter: Filter): Seq[Filter] = { + filter match { + case And(filter1, filter2) => + splitConjunctiveFilters(filter1) ++ splitConjunctiveFilters(filter2) + case other => other :: Nil + } + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkSQLProperties.java b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkSQLProperties.java new file mode 100644 index 000000000000..2101c95f2817 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkSQLProperties.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark; + +/** + * Spark SQL properties for Paimon. + */ +public class SparkSQLProperties { + private SparkSQLProperties(){} + + public static final String USE_V2_WRITE = "spark.sql.paimon.use-v2-write"; + public static final String USE_V2_WRITE_DEFAULT = "false"; +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala index b9a90d8b5bef..f6e6a5a51a9a 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala @@ -20,10 +20,12 @@ package org.apache.paimon.spark import org.apache.paimon.CoreOptions import org.apache.paimon.options.Options +import org.apache.paimon.spark.catalog.functions.BucketFunction import org.apache.paimon.spark.schema.PaimonMetadataColumn -import org.apache.paimon.table.{DataTable, FileStoreTable, KnownSplitsTable, Table} +import org.apache.paimon.table.{BucketMode, DataTable, FileStoreTable, KnownSplitsTable, Table} import org.apache.paimon.utils.StringUtils +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, SupportsRead, SupportsWrite, TableCapability, TableCatalog} import org.apache.spark.sql.connector.expressions.{Expressions, Transform} import org.apache.spark.sql.connector.read.ScanBuilder @@ -43,6 +45,27 @@ case class SparkTable(table: Table) with SupportsMetadataColumns with PaimonPartitionManagement { + private lazy val sparkSession: SparkSession = SparkSession.active + private lazy val useV2Write: Boolean = { + val v2WriteConfigured = sparkSession.conf + .get(SparkSQLProperties.USE_V2_WRITE, SparkSQLProperties.USE_V2_WRITE_DEFAULT) + .toBoolean + v2WriteConfigured && supportsV2Write + } + + private def supportsV2Write: Boolean = { + table match { + case storeTable: FileStoreTable => + storeTable.bucketMode() match { + case BucketMode.HASH_FIXED => BucketFunction.supportsTable(storeTable) + case BucketMode.BUCKET_UNAWARE => true + case _ => false + } + + case _ => false + } + } + def getTable: Table = table override def name: String = table.fullName @@ -73,14 +96,18 @@ case class SparkTable(table: Table) } override def capabilities: JSet[TableCapability] = { - JEnumSet.of( + val capabilities = JEnumSet.of( TableCapability.ACCEPT_ANY_SCHEMA, TableCapability.BATCH_READ, - TableCapability.V1_BATCH_WRITE, TableCapability.OVERWRITE_BY_FILTER, TableCapability.OVERWRITE_DYNAMIC, TableCapability.MICRO_BATCH_READ ) + + capabilities.add( + if (useV2Write) TableCapability.BATCH_WRITE else TableCapability.V1_BATCH_WRITE) + + capabilities } override def metadataColumns: Array[MetadataColumn] = { @@ -105,7 +132,12 @@ case class SparkTable(table: Table) override def newWriteBuilder(info: LogicalWriteInfo): 
WriteBuilder = { table match { case fileStoreTable: FileStoreTable => - new SparkWriteBuilder(fileStoreTable, Options.fromMap(info.options)) + val options = Options.fromMap(info.options) + if (useV2Write) { + new SparkV2WriteBuilder(fileStoreTable, options) + } else { + new SparkWriteBuilder(fileStoreTable, options) + } case _ => throw new RuntimeException("Only FileStoreTable can be written.") } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2WriteBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2WriteBuilder.scala new file mode 100644 index 000000000000..377992a34720 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2WriteBuilder.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark + +import org.apache.paimon.options.Options +import org.apache.paimon.spark.source.SparkV2Write +import org.apache.paimon.table.FileStoreTable + +import org.apache.spark.sql.connector.write.{SupportsDynamicOverwrite, SupportsOverwrite, WriteBuilder} +import org.apache.spark.sql.sources.{And, Filter} + +import scala.collection.JavaConverters._ + +private class SparkV2WriteBuilder(table: FileStoreTable, options: Options) + extends BaseWriteBuilder(table, options) + with SupportsOverwrite + with SupportsDynamicOverwrite { + + private var overwriteDynamic = false + private var overwritePartitions: Map[String, String] = null + override def build = new SparkV2Write(table, overwriteDynamic, overwritePartitions.asJava) + override def overwrite(filters: Array[Filter]): WriteBuilder = { + if (overwriteDynamic) { + throw new IllegalArgumentException("Cannot overwrite dynamically and by filter both") + } + + failIfCanNotOverwrite(filters) + + val conjunctiveFilters = if (filters.nonEmpty) { + Some(filters.reduce((l, r) => And(l, r))) + } else { + None + } + + if (isTruncate(conjunctiveFilters.get)) { + overwritePartitions = Map.empty[String, String] + } else { + overwritePartitions = + convertPartitionFilterToMap(conjunctiveFilters.get, table.schema.logicalPartitionType()) + } + + this + } + + override def overwriteDynamicPartitions(): WriteBuilder = { + if (overwritePartitions != null) { + throw new IllegalArgumentException("Cannot overwrite dynamically and by filter both") + } + + overwriteDynamic = true + this + } +} diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java index fff94ce0374d..6006a93181c0 100644 --- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java +++ 
b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkWriteITCase.java @@ -26,8 +26,12 @@ import org.apache.paimon.table.FileStoreTable; import org.apache.paimon.table.FileStoreTableFactory; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.internal.SQLConf; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; @@ -42,6 +46,7 @@ import java.util.List; import java.util.stream.Collectors; +import static org.apache.spark.sql.functions.expr; import static org.assertj.core.api.Assertions.assertThat; /** ITCase for spark writer. */ @@ -81,6 +86,53 @@ public void afterEach() { spark.sql("DROP TABLE T"); } + @Test + public void testV2Write() { + spark.sql( + "CREATE TABLE T (a INT, b INT, c STRING) partitioned by (a) TBLPROPERTIES" + + " ('primary-key'='a,b', 'bucket'='4', 'file.format'='parquet')"); + + spark.conf().set(SparkSQLProperties.USE_V2_WRITE, true); + + spark.sql("INSERT INTO T VALUES(1, 1, 'a'), (2, 2, 'b')"); + spark.sql("INSERT INTO T VALUES (1, 1, 'A')"); + + List rows = spark.sql("SELECT * FROM T").collectAsList(); + assertThat(rows.toString()).isEqualTo("[[1,1,A], [2,2,b]]"); + } + + @Test + public void testStaticV2Overwrite() { + spark.sql( + "CREATE TABLE T (a INT, b INT, c STRING) partitioned by (a) TBLPROPERTIES" + + " ('primary-key'='a,b', 'bucket'='4', 'file.format'='parquet')"); + + spark.conf().set(SparkSQLProperties.USE_V2_WRITE, true); + spark.conf().set(SQLConf.PARTITION_OVERWRITE_MODE(), "STATIC"); + + spark.sql("INSERT INTO T VALUES(1, 1, 'a'), (2, 2, 'b')"); + spark.sql("INSERT OVERWRITE T VALUES (1, 1, 'A')"); + + List rows = spark.sql("SELECT * FROM T").collectAsList(); + assertThat(rows.toString()).isEqualTo("[[1,1,A]]"); + } + + @Test + public void testDynamicV2Overwrite() { + spark.sql( + "CREATE TABLE T (a INT, b INT, c STRING) partitioned by (a) TBLPROPERTIES" + + " ('primary-key'='a,b', 'bucket'='4', 'file.format'='parquet')"); + + spark.conf().set(SparkSQLProperties.USE_V2_WRITE, true); + spark.conf().set(SQLConf.PARTITION_OVERWRITE_MODE(), "DYNAMIC"); + + spark.sql("INSERT INTO T VALUES(1, 1, 'a'), (2, 2, 'b')"); + spark.sql("INSERT OVERWRITE T VALUES (1, 1, 'A')"); + + List rows = spark.sql("SELECT * FROM T").collectAsList(); + assertThat(rows.toString()).isEqualTo("[[2,2,b], [1,1,A]]"); + } + @Test public void testWrite() { spark.sql( @@ -89,6 +141,55 @@ public void testWrite() { innerSimpleWrite(); } + @Test + public void testWrite1() throws NoSuchTableException { + spark.sql( + "CREATE TABLE T (a bigint, b INT, c STRING) TBLPROPERTIES" + + " ('primary-key'='a', 'bucket'='4', 'file.format'='parquet')"); + Dataset df = + spark.range(10) + .withColumnRenamed("id", "a") + .withColumn("b", functions.expr("CAST(a as INT)")) + .withColumn("c", functions.expr("CAST(a as STRING)")) + .coalesce(1); + + df.writeTo("T").append(); + } + + @Test + public void testWrite2() throws NoSuchTableException { + + spark.sql( + String.format( + "CREATE TABLE %s (" + + "intCol INT NOT NULL, " + + "longCol BIGINT NOT NULL, " + + "floatCol FLOAT NOT NULL, " + + "doubleCol DOUBLE NOT NULL, " + + "decimalCol DECIMAL(20, 5) NOT NULL, " + + "stringCol1 STRING NOT NULL, " + + "stringCol2 STRING NOT NULL, " + + "stringCol3 STRING NOT NULL" + + ") using paimon " + + 
"TBLPROPERTIES('primary-key'='stringCol1,stringCol2', 'bucket'='4')", + "T")); + + benchmarkData().writeTo("T").append(); + } + + private Dataset benchmarkData() { + return spark.range(1) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(20, 5))")) + .withColumn("stringCol1", expr("CAST(longCol AS STRING)")) + .withColumn("stringCol2", expr("CAST(longCol AS STRING)")) + .withColumn("stringCol3", expr("CAST(longCol AS STRING)")) + .coalesce(1); + } + @Test public void testWritePartitionTable() { spark.sql( diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/catalog/functions/BucketFunctionTest.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/catalog/functions/BucketFunctionTest.java new file mode 100644 index 000000000000..0ad77337a44f --- /dev/null +++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/catalog/functions/BucketFunctionTest.java @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.catalog.functions; + +import org.apache.paimon.CoreOptions; +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.Decimal; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.data.InternalRow; +import org.apache.paimon.data.Timestamp; +import org.apache.paimon.schema.TableSchema; +import org.apache.paimon.spark.SparkTypeUtils; +import org.apache.paimon.table.sink.FixedBucketRowKeyExtractor; +import org.apache.paimon.types.BigIntType; +import org.apache.paimon.types.BooleanType; +import org.apache.paimon.types.DataField; +import org.apache.paimon.types.DecimalType; +import org.apache.paimon.types.DoubleType; +import org.apache.paimon.types.FloatType; +import org.apache.paimon.types.IntType; +import org.apache.paimon.types.LocalZonedTimestampType; +import org.apache.paimon.types.RowType; +import org.apache.paimon.types.SmallIntType; +import org.apache.paimon.types.TimestampType; +import org.apache.paimon.types.TinyIntType; +import org.apache.paimon.types.VarBinaryType; +import org.apache.paimon.types.VarCharType; + +import org.apache.paimon.shade.guava30.com.google.common.base.Joiner; +import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableMap; + +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for Spark bucket functions. 
*/ +public class BucketFunctionTest { + private static final int NUM_BUCKETS = + ThreadLocalRandom.current().nextInt(1, Integer.MAX_VALUE); + + private static final String BOOLEAN_COL = "boolean_col"; + private static final String BYTE_COL = "byte_col"; + private static final String SHORT_COL = "short_col"; + private static final String INTEGER_COL = "integer_col"; + private static final String LONG_COL = "long_col"; + private static final String FLOAT_COL = "float_col"; + private static final String DOUBLE_COL = "double_col"; + private static final String STRING_COL = "string_col"; + private static final String DECIMAL_COL = "decimal_col"; + private static final String COMPACTED_DECIMAL_COL = "compacted_decimal_col"; + private static final String TIMESTAMP_COL = "timestamp_col"; + private static final String LZ_TIMESTAMP_COL = "lz_timestamp_col"; + private static final String BINARY_COL = "binary_col"; + + private static final int DECIMAL_PRECISION = 38; + private static final int DECIMAL_SCALE = 18; + private static final int COMPACTED_DECIMAL_PRECISION = 18; + private static final int COMPACTED_DECIMAL_SCALE = 9; + private static final int TIMESTAMP_PRECISION = 6; + + private static final RowType ROW_TYPE = + new RowType( + Arrays.asList( + new DataField(0, BOOLEAN_COL, new BooleanType()), + new DataField(1, BYTE_COL, new TinyIntType()), + new DataField(2, SHORT_COL, new SmallIntType()), + new DataField(3, INTEGER_COL, new IntType()), + new DataField(4, LONG_COL, new BigIntType()), + new DataField(5, FLOAT_COL, new FloatType()), + new DataField(6, DOUBLE_COL, new DoubleType()), + new DataField(7, STRING_COL, new VarCharType(VarCharType.MAX_LENGTH)), + new DataField( + 8, + DECIMAL_COL, + new DecimalType(DECIMAL_PRECISION, DECIMAL_SCALE)), + new DataField( + 9, + COMPACTED_DECIMAL_COL, + new DecimalType( + COMPACTED_DECIMAL_PRECISION, COMPACTED_DECIMAL_SCALE)), + new DataField( + 10, TIMESTAMP_COL, new TimestampType(TIMESTAMP_PRECISION)), + new DataField( + 11, + LZ_TIMESTAMP_COL, + new LocalZonedTimestampType(TIMESTAMP_PRECISION)), + new DataField( + 12, BINARY_COL, new VarBinaryType(VarBinaryType.MAX_LENGTH)))); + + private static final StructType SPARK_TYPE = SparkTypeUtils.fromPaimonRowType(ROW_TYPE); + + private static InternalRow randomPaimonInternalRow() { + Random random = new Random(); + BigInteger unscaled = new BigInteger(String.valueOf(random.nextInt())); + BigDecimal bigDecimal1 = new BigDecimal(unscaled, DECIMAL_SCALE); + BigDecimal bigDecimal2 = new BigDecimal(unscaled, COMPACTED_DECIMAL_SCALE); + + return GenericRow.of( + random.nextBoolean(), + (byte) random.nextInt(), + (short) random.nextInt(), + random.nextInt(), + random.nextLong(), + random.nextFloat(), + random.nextDouble(), + BinaryString.fromString(UUID.randomUUID().toString()), + Decimal.fromBigDecimal(bigDecimal1, DECIMAL_PRECISION, DECIMAL_SCALE), + Decimal.fromBigDecimal( + bigDecimal2, COMPACTED_DECIMAL_PRECISION, COMPACTED_DECIMAL_SCALE), + Timestamp.now(), + Timestamp.now(), + UUID.randomUUID().toString().getBytes()); + } + + private static final InternalRow NULL_PAIMON_ROW = + GenericRow.of( + null, null, null, null, null, null, null, null, null, null, null, null, null); + + @Test + public void testBooleanType() { + InternalRow internalRow = randomPaimonInternalRow(); + String[] bucketColumns = {BOOLEAN_COL}; + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + @Test + public void testByteType() { + InternalRow internalRow = 
randomPaimonInternalRow(); + String[] bucketColumns = {BYTE_COL}; + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + @Test + public void testShortType() { + InternalRow internalRow = randomPaimonInternalRow(); + String[] bucketColumns = {SHORT_COL}; + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + @Test + public void testIntegerType() { + InternalRow internalRow = randomPaimonInternalRow(); + String[] bucketColumns = {INTEGER_COL}; + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + @Test + public void testLongType() { + InternalRow internalRow = randomPaimonInternalRow(); + String[] bucketColumns = {LONG_COL}; + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + @Test + public void testFloatType() { + InternalRow internalRow = randomPaimonInternalRow(); + String[] bucketColumns = {FLOAT_COL}; + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + @Test + public void testDoubleType() { + InternalRow internalRow = randomPaimonInternalRow(); + String[] bucketColumns = {DOUBLE_COL}; + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + @Test + public void testStringType() { + InternalRow internalRow = randomPaimonInternalRow(); + String[] bucketColumns = {STRING_COL}; + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + @Test + public void testBinaryType() { + InternalRow internalRow = randomPaimonInternalRow(); + String[] bucketColumns = {BINARY_COL}; + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + @Test + public void testTimestampType() { + InternalRow internalRow = randomPaimonInternalRow(); + String[] bucketColumns = {TIMESTAMP_COL}; + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + @Test + public void testLZTimestampType() { + InternalRow internalRow = randomPaimonInternalRow(); + String[] bucketColumns = {LZ_TIMESTAMP_COL}; + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + @Test + public void testDecimalType() { + InternalRow internalRow = randomPaimonInternalRow(); + String[] bucketColumns = {DECIMAL_COL}; + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + @Test + public void testCompactedDecimalType() { + InternalRow internalRow = randomPaimonInternalRow(); + String[] bucketColumns = {COMPACTED_DECIMAL_COL}; + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + @Test + public void testGenericType() { + InternalRow internalRow = randomPaimonInternalRow(); + + List allColumns = ROW_TYPE.getFieldNames(); + // the order of columns matters + Collections.shuffle(allColumns); + String[] bucketColumns = null; + while (bucketColumns == null || bucketColumns.length < 2) { + bucketColumns = + allColumns.stream() + .filter(e -> ThreadLocalRandom.current().nextBoolean()) + .toArray(String[]::new); + } + + assertBucketEquals(internalRow, bucketColumns); + assertBucketEquals(NULL_PAIMON_ROW, bucketColumns); + } + + private static void assertBucketEquals(InternalRow 
paimonRow, String... bucketColumns) { + assertThat(bucketOfSparkFunction(paimonRow, bucketColumns)) + .as( + String.format( + "The bucket computed by spark function should be identical to the bucket computed by paimon bucket extractor function, row: %s", + paimonRow)) + .isEqualTo(bucketOfPaimonFunction(paimonRow, bucketColumns)); + } + + private static Object[] columnSparkValues(InternalRow paimonRow, String... columns) { + Object[] values = new Object[columns.length]; + + for (int i = 0; i < columns.length; i += 1) { + String column = columns[i]; + org.apache.paimon.types.DataType paimonType = ROW_TYPE.getField(column).type(); + int fieldIndex = ROW_TYPE.getFieldIndex(column); + + if (paimonRow.isNullAt(fieldIndex)) { + values[i] = null; + continue; + } + + if (paimonType instanceof BooleanType) { + values[i] = paimonRow.getBoolean(fieldIndex); + } else if (paimonType instanceof TinyIntType) { + values[i] = paimonRow.getByte(fieldIndex); + } else if (paimonType instanceof SmallIntType) { + values[i] = paimonRow.getShort(fieldIndex); + } else if (paimonType instanceof IntType) { + values[i] = paimonRow.getInt(fieldIndex); + } else if (paimonType instanceof BigIntType) { + values[i] = paimonRow.getLong(fieldIndex); + } else if (paimonType instanceof FloatType) { + values[i] = paimonRow.getFloat(fieldIndex); + } else if (paimonType instanceof DoubleType) { + values[i] = paimonRow.getDouble(fieldIndex); + } else if (paimonType instanceof VarCharType) { + values[i] = UTF8String.fromBytes(paimonRow.getString(fieldIndex).toBytes()); + } else if (paimonType instanceof TimestampType + || paimonType instanceof LocalZonedTimestampType) { + values[i] = paimonRow.getTimestamp(fieldIndex, 9).toMicros(); + } else if (paimonType instanceof DecimalType) { + int precision = ((DecimalType) paimonType).getPrecision(); + int scale = ((DecimalType) paimonType).getScale(); + Decimal paimonDecimal = paimonRow.getDecimal(fieldIndex, precision, scale); + values[i] = + org.apache.spark.sql.types.Decimal.apply( + paimonDecimal.toBigDecimal(), precision, scale); + } else if (paimonType instanceof VarBinaryType) { + values[i] = paimonRow.getBinary(fieldIndex); + } else { + throw new UnsupportedOperationException( + "Unsupported type: " + paimonType.asSQLString()); + } + } + + return values; + } + + private static Object[] invokeParameters(Object[] sparkColumnValues) { + Object[] ret = new Object[sparkColumnValues.length + 1]; + ret[0] = NUM_BUCKETS; + System.arraycopy(sparkColumnValues, 0, ret, 1, sparkColumnValues.length); + return ret; + } + + private static StructType bucketColumnsSparkType(String... columns) { + StructField[] inputFields = new StructField[columns.length + 1]; + inputFields[0] = + new StructField("num_buckets", DataTypes.IntegerType, false, Metadata.empty()); + + for (int i = 0; i < columns.length; i += 1) { + String column = columns[i]; + inputFields[i + 1] = SPARK_TYPE.apply(column); + } + + return new StructType(inputFields); + } + + private static int bucketOfSparkFunction(InternalRow paimonRow, String... 
bucketColumns) { + BucketFunction unbound = new BucketFunction(); + StructType inputSqlType = bucketColumnsSparkType(bucketColumns); + BoundFunction boundFunction = unbound.bind(inputSqlType); + + Object[] inputValues = invokeParameters(columnSparkValues(paimonRow, bucketColumns)); + + // null values can only be handled by #produceResult + if (boundFunction.getClass() == BucketFunction.BucketGeneric.class + || Arrays.stream(inputValues).anyMatch(v -> v == null)) { + return ((BucketFunction.BucketGeneric) boundFunction) + .produceResult(new GenericInternalRow(inputValues)); + } else { + Class[] parameterTypes = new Class[inputSqlType.fields().length]; + StructField[] inputFields = inputSqlType.fields(); + for (int i = 0; i < inputSqlType.fields().length; i += 1) { + DataType columnType = inputFields[i].dataType(); + if (columnType == DataTypes.BooleanType) { + parameterTypes[i] = boolean.class; + } else if (columnType == DataTypes.ByteType) { + parameterTypes[i] = byte.class; + } else if (columnType == DataTypes.ShortType) { + parameterTypes[i] = short.class; + } else if (columnType == DataTypes.IntegerType) { + parameterTypes[i] = int.class; + } else if (columnType == DataTypes.LongType + || columnType == DataTypes.TimestampType + || columnType == DataTypes.TimestampNTZType) { + parameterTypes[i] = long.class; + } else if (columnType == DataTypes.FloatType) { + parameterTypes[i] = float.class; + } else if (columnType == DataTypes.DoubleType) { + parameterTypes[i] = double.class; + } else if (columnType == DataTypes.StringType) { + parameterTypes[i] = UTF8String.class; + } else if (columnType instanceof org.apache.spark.sql.types.DecimalType) { + parameterTypes[i] = org.apache.spark.sql.types.Decimal.class; + } else if (columnType == DataTypes.BinaryType) { + parameterTypes[i] = byte[].class; + } else { + throw new UnsupportedOperationException( + "Unsupported type: " + columnType.sql()); + } + } + + try { + Method invoke = boundFunction.getClass().getMethod("invoke", parameterTypes); + return (int) invoke.invoke(boundFunction, inputValues); + } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + } + + private static int bucketOfPaimonFunction(InternalRow internalRow, String... bucketColumns) { + List fields = ROW_TYPE.getFields(); + TableSchema schema = + new TableSchema( + 0, + fields, + RowType.currentHighestFieldId(fields), + Collections.emptyList(), + Collections.emptyList(), + ImmutableMap.of( + CoreOptions.BUCKET.key(), + String.valueOf(NUM_BUCKETS), + CoreOptions.BUCKET_KEY.key(), + Joiner.on(",").join(bucketColumns)), + ""); + + FixedBucketRowKeyExtractor fixedBucketRowKeyExtractor = + new FixedBucketRowKeyExtractor(schema); + fixedBucketRowKeyExtractor.setRecord(internalRow); + return fixedBucketRowKeyExtractor.bucket(); + } +}
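Note: the tests above exercise the new path end to end; for reviewers, a minimal usage sketch of the V2 write path in a Scala spark-shell (the table name and session setup are illustrative, not part of this patch):

    // Enable the DataSourceV2 write path; SparkSQLProperties.USE_V2_WRITE defaults to "false".
    spark.conf.set("spark.sql.paimon.use-v2-write", "true")
    // SparkTable only reports BATCH_WRITE when the table is bucket-unaware or uses fixed hash
    // bucketing whose bucket key types are supported by BucketFunction; otherwise it keeps V1_BATCH_WRITE.
    spark.sql("CREATE TABLE T (a INT, b INT, c STRING) PARTITIONED BY (a) " +
      "TBLPROPERTIES ('primary-key'='a,b', 'bucket'='4')")
    spark.sql("INSERT INTO T VALUES (1, 1, 'a'), (2, 2, 'b')")
    // INSERT OVERWRITE is routed to SparkV2WriteBuilder#overwrite (static mode) or
    // #overwriteDynamicPartitions (dynamic mode) via spark.sql.sources.partitionOverwriteMode.
    spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
    spark.sql("INSERT OVERWRITE T VALUES (1, 1, 'A')")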