Commit 4e49f2d

[spark] support spark batch write (#2277)
1 parent c10232e commit 4e49f2d

19 files changed: +1127 -12 lines changed


fluss-common/src/main/java/org/apache/fluss/row/TimestampLtz.java

Lines changed: 2 additions & 4 deletions
@@ -23,6 +23,8 @@
 import java.io.Serializable;
 import java.time.Instant;
 
+import static org.apache.fluss.utils.DateTimeUtils.MICROS_PER_MILLIS;
+import static org.apache.fluss.utils.DateTimeUtils.NANOS_PER_MICROS;
 import static org.apache.fluss.utils.Preconditions.checkArgument;
 
 /**
@@ -39,10 +41,6 @@ public class TimestampLtz implements Comparable<TimestampLtz>, Serializable {
 
     private static final long serialVersionUID = 1L;
 
-    public static final long MICROS_PER_MILLIS = 1000L;
-
-    public static final long NANOS_PER_MICROS = 1000L;
-
     // this field holds the integral second and the milli-of-second.
    private final long millisecond;
 

fluss-common/src/main/java/org/apache/fluss/row/TimestampNtz.java

Lines changed: 16 additions & 0 deletions
@@ -25,6 +25,8 @@
 import java.time.LocalDateTime;
 import java.time.LocalTime;
 
+import static org.apache.fluss.utils.DateTimeUtils.MICROS_PER_MILLIS;
+import static org.apache.fluss.utils.DateTimeUtils.NANOS_PER_MICROS;
 import static org.apache.fluss.utils.Preconditions.checkArgument;
 
 /**
@@ -70,6 +72,20 @@ public int getNanoOfMillisecond() {
         return nanoOfMillisecond;
     }
 
+    /**
+     * Creates an instance of {@link TimestampNtz} from microseconds.
+     *
+     * <p>The nanos-of-millisecond field will be set to zero.
+     *
+     * @param microseconds the number of microseconds since {@code 1970-01-01 00:00:00}; a negative
+     *     number is the number of microseconds before {@code 1970-01-01 00:00:00}
+     */
+    public static TimestampNtz fromMicros(long microseconds) {
+        long mills = Math.floorDiv(microseconds, MICROS_PER_MILLIS);
+        long nanos = (microseconds - mills * MICROS_PER_MILLIS) * NANOS_PER_MICROS;
+        return TimestampNtz.fromMillis(mills, (int) nanos);
+    }
+
     /**
      * Creates an instance of {@link TimestampNtz} from milliseconds.
      *
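For intuition, the fromMicros split relies on Math.floorDiv so that negative epoch values round toward negative infinity and the sub-millisecond remainder stays non-negative. A minimal Scala sketch of the same arithmetic (the object and method names here are illustrative, not part of the commit):

    // Split epoch microseconds into a millisecond part and a nanos-of-millisecond remainder,
    // mirroring the arithmetic in TimestampNtz.fromMicros above.
    object FromMicrosSketch {
      private val MicrosPerMillis = 1000L
      private val NanosPerMicros = 1000L

      def split(micros: Long): (Long, Int) = {
        val millis = Math.floorDiv(micros, MicrosPerMillis) // rounds toward negative infinity
        val nanos = (micros - millis * MicrosPerMillis) * NanosPerMicros
        (millis, nanos.toInt)
      }

      def main(args: Array[String]): Unit = {
        println(split(1234567L)) // (1234, 567000)
        println(split(-1L))      // (-1, 999000)
      }
    }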

fluss-common/src/main/java/org/apache/fluss/utils/DateTimeUtils.java

Lines changed: 6 additions & 0 deletions
@@ -47,6 +47,12 @@ public class DateTimeUtils {
     /** The julian date of the epoch, 1970-01-01. */
     public static final int EPOCH_JULIAN = 2440588;
 
+    /** The number of microseconds in a millisecond. */
+    public static final long MICROS_PER_MILLIS = 1000L;
+
+    /** The number of nanoseconds in a microsecond. */
+    public static final long NANOS_PER_MICROS = 1000L;
+
     /** The number of milliseconds in a second. */
     private static final long MILLIS_PER_SECOND = 1000L;
 

fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkCatalog.scala

Lines changed: 2 additions & 1 deletion
@@ -49,7 +49,8 @@ class SparkCatalog extends TableCatalog with SupportsFlussNamespaces with WithFlussAdmin {
 
   override def loadTable(ident: Identifier): Table = {
     try {
-      SparkTable(admin.getTableInfo(toTablePath(ident)).get())
+      val tablePath = toTablePath(ident)
+      SparkTable(tablePath, admin.getTableInfo(tablePath).get(), flussConfig)
     } catch {
       case e: ExecutionException if e.getCause.isInstanceOf[TableNotExistException] =>
         throw new NoSuchTableException(ident)

fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkTable.scala

Lines changed: 19 additions & 4 deletions
@@ -17,9 +17,24 @@
 
 package org.apache.fluss.spark
 
-import org.apache.fluss.metadata.TableInfo
+import org.apache.fluss.config.{Configuration => FlussConfiguration}
+import org.apache.fluss.metadata.{TableInfo, TablePath}
 import org.apache.fluss.spark.catalog.{AbstractSparkTable, SupportsFlussPartitionManagement}
+import org.apache.fluss.spark.write.{FlussAppendWriteBuilder, FlussUpsertWriteBuilder}
 
-case class SparkTable(table: TableInfo)
-  extends AbstractSparkTable(table)
-  with SupportsFlussPartitionManagement {}
+import org.apache.spark.sql.connector.catalog.SupportsWrite
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
+
+case class SparkTable(tablePath: TablePath, tableInfo: TableInfo, flussConfig: FlussConfiguration)
+  extends AbstractSparkTable(tableInfo)
+  with SupportsFlussPartitionManagement
+  with SupportsWrite {
+
+  override def newWriteBuilder(logicalWriteInfo: LogicalWriteInfo): WriteBuilder = {
+    if (tableInfo.getPrimaryKeys.isEmpty) {
+      new FlussAppendWriteBuilder(tablePath, logicalWriteInfo.schema(), flussConfig)
+    } else {
+      new FlussUpsertWriteBuilder(tablePath, logicalWriteInfo.schema(), flussConfig)
+    }
+  }
+}
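With SparkTable now implementing SupportsWrite (and AbstractSparkTable advertising BATCH_WRITE below), a plain batch write is routed through newWriteBuilder: tables without a primary key take the append path, primary-key tables take the upsert path. A minimal usage sketch, assuming the Fluss catalog is registered in the Spark session under the name "fluss" and a table "fluss.db.t" already exists (both names are illustrative):

    import org.apache.spark.sql.SparkSession

    // Batch-write a small DataFrame into a Fluss table through the DataFrameWriterV2 API.
    val spark = SparkSession.builder()
      .appName("fluss-batch-write-sketch")
      .getOrCreate()

    val df = spark.range(0, 10).selectExpr("id", "cast(id as string) as name")

    // No primary key  -> FlussAppendWriteBuilder
    // Primary key set -> FlussUpsertWriteBuilder
    df.writeTo("fluss.db.t").append()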

fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/catalog/AbstractSparkTable.scala

Lines changed: 5 additions & 2 deletions
@@ -17,7 +17,8 @@
 
 package org.apache.fluss.spark.catalog
 
-import org.apache.fluss.metadata.TableInfo
+import org.apache.fluss.config.{Configuration => FlussConfiguration}
+import org.apache.fluss.metadata.{TableInfo, TablePath}
 import org.apache.fluss.spark.SparkConversions
 
 import org.apache.spark.sql.connector.catalog.{Table, TableCapability}
@@ -39,5 +40,7 @@ abstract class AbstractSparkTable(tableInfo: TableInfo) extends Table {
 
   override def schema(): StructType = _schema
 
-  override def capabilities(): util.Set[TableCapability] = Set.empty[TableCapability].asJava
+  override def capabilities(): util.Set[TableCapability] = {
+    Set(TableCapability.BATCH_WRITE).asJava
+  }
 }

fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/catalog/WithFlussAdmin.scala

Lines changed: 8 additions & 1 deletion
@@ -33,6 +33,7 @@ trait WithFlussAdmin extends AutoCloseable {
 
   private var _connection: Connection = _
   private var _admin: Admin = _
+  private var _flussConfig: FlussConfiguration = _
 
   // TODO: init lake spark catalog
   protected var lakeCatalog: CatalogPlugin = _
@@ -43,10 +44,16 @@
       entry: util.Map.Entry[String, String] => flussConfigs.put(entry.getKey, entry.getValue)
     }
 
-    _connection = ConnectionFactory.createConnection(FlussConfiguration.fromMap(flussConfigs))
+    _flussConfig = FlussConfiguration.fromMap(flussConfigs)
+    _connection = ConnectionFactory.createConnection(_flussConfig)
    _admin = _connection.getAdmin
   }
 
+  protected def flussConfig: FlussConfiguration = {
+    Preconditions.checkNotNull(_flussConfig, "Fluss Configuration is not initialized.")
+    _flussConfig
+  }
+
   protected def admin: Admin = {
     Preconditions.checkNotNull(_admin, "Fluss Admin is not initialized.")
     _admin

Lines changed: 132 additions & 0 deletions (new file: class SparkAsFlussArray in package org.apache.fluss.spark.row)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.fluss.spark.row

import org.apache.fluss.row.{BinaryString, Decimal, InternalArray => FlussInternalArray, InternalRow => FlussInternalRow, TimestampLtz, TimestampNtz}

import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData}
import org.apache.spark.sql.types.{ArrayType => SparkArrayType, DataType => SparkDataType, StructType}

/** Wraps a Spark [[SparkArrayData]] as a Fluss [[FlussInternalArray]]. */
class SparkAsFlussArray(arrayData: SparkArrayData, elementType: SparkDataType)
  extends FlussInternalArray
  with Serializable {

  /** Returns the number of elements in this array. */
  override def size(): Int = arrayData.numElements()

  override def toBooleanArray: Array[Boolean] = arrayData.toBooleanArray()

  override def toByteArray: Array[Byte] = arrayData.toByteArray()

  override def toShortArray: Array[Short] = arrayData.toShortArray()

  override def toIntArray: Array[Int] = arrayData.toIntArray()

  override def toLongArray: Array[Long] = arrayData.toLongArray()

  override def toFloatArray: Array[Float] = arrayData.toFloatArray()

  override def toDoubleArray: Array[Double] = arrayData.toDoubleArray()

  /** Returns true if the element is null at the given position. */
  override def isNullAt(pos: Int): Boolean = arrayData.isNullAt(pos)

  /** Returns the boolean value at the given position. */
  override def getBoolean(pos: Int): Boolean = arrayData.getBoolean(pos)

  /** Returns the byte value at the given position. */
  override def getByte(pos: Int): Byte = arrayData.getByte(pos)

  /** Returns the short value at the given position. */
  override def getShort(pos: Int): Short = arrayData.getShort(pos)

  /** Returns the integer value at the given position. */
  override def getInt(pos: Int): Int = arrayData.getInt(pos)

  /** Returns the long value at the given position. */
  override def getLong(pos: Int): Long = arrayData.getLong(pos)

  /** Returns the float value at the given position. */
  override def getFloat(pos: Int): Float = arrayData.getFloat(pos)

  /** Returns the double value at the given position. */
  override def getDouble(pos: Int): Double = arrayData.getDouble(pos)

  /** Returns the string value at the given position with fixed length. */
  override def getChar(pos: Int, length: Int): BinaryString =
    BinaryString.fromBytes(arrayData.getUTF8String(pos).getBytes)

  /** Returns the string value at the given position. */
  override def getString(pos: Int): BinaryString =
    BinaryString.fromBytes(arrayData.getUTF8String(pos).getBytes)

  /**
   * Returns the decimal value at the given position.
   *
   * <p>The precision and scale are required to determine whether the decimal value was stored in a
   * compact representation (see [[Decimal]]).
   */
  override def getDecimal(pos: Int, precision: Int, scale: Int): Decimal = {
    val sparkDecimal = arrayData.getDecimal(pos, precision, scale)
    if (sparkDecimal.precision <= org.apache.spark.sql.types.Decimal.MAX_LONG_DIGITS)
      Decimal.fromUnscaledLong(
        sparkDecimal.toUnscaledLong,
        sparkDecimal.precision,
        sparkDecimal.scale)
    else
      Decimal.fromBigDecimal(
        sparkDecimal.toJavaBigDecimal,
        sparkDecimal.precision,
        sparkDecimal.scale)
  }

  /**
   * Returns the timestamp value at the given position.
   *
   * <p>The precision is required to determine whether the timestamp value was stored in a compact
   * representation (see [[TimestampNtz]]).
   */
  override def getTimestampNtz(pos: Int, precision: Int): TimestampNtz =
    TimestampNtz.fromMicros(arrayData.getLong(pos))

  /**
   * Returns the timestamp value at the given position.
   *
   * <p>The precision is required to determine whether the timestamp value was stored in a compact
   * representation (see [[TimestampLtz]]).
   */
  override def getTimestampLtz(pos: Int, precision: Int): TimestampLtz =
    TimestampLtz.fromEpochMicros(arrayData.getLong(pos))

  /** Returns the binary value at the given position with fixed length. */
  override def getBinary(pos: Int, length: Int): Array[Byte] = arrayData.getBinary(pos)

  /** Returns the binary value at the given position. */
  override def getBytes(pos: Int): Array[Byte] = arrayData.getBinary(pos)

  /** Returns the array value at the given position. */
  override def getArray(pos: Int) = new SparkAsFlussArray(
    arrayData.getArray(pos),
    elementType.asInstanceOf[SparkArrayType].elementType)

  /** Returns the row value at the given position. */
  override def getRow(pos: Int, numFields: Int): FlussInternalRow =
    new SparkAsFlussRow(elementType.asInstanceOf[StructType])
      .replace(arrayData.getStruct(pos, numFields))
}
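The wrapper is a positional adapter: Spark's ArrayData is read element by element through the Fluss InternalArray interface rather than copied up front. A minimal sketch of how it could be exercised, assuming this class is on the classpath (the test values are illustrative):

    import org.apache.fluss.spark.row.SparkAsFlussArray
    import org.apache.spark.sql.catalyst.util.GenericArrayData
    import org.apache.spark.sql.types.IntegerType

    // Wrap a Spark in-memory array as a Fluss InternalArray and read it positionally.
    val sparkArray = new GenericArrayData(Array[Any](1, 2, null))
    val flussArray = new SparkAsFlussArray(sparkArray, IntegerType)

    assert(flussArray.size() == 3)
    assert(flussArray.getInt(0) == 1)
    assert(flussArray.isNullAt(2))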
