
Commit edcfb7a

[spark] support write log/pk table

1 parent: 920809f

File tree

16 files changed: +1163 -8 lines

fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkCatalog.scala

Lines changed: 2 additions & 1 deletion
Lines changed: 2 additions & 1 deletion

@@ -49,7 +49,8 @@ class SparkCatalog extends TableCatalog with SupportsFlussNamespaces with WithFl
 
   override def loadTable(ident: Identifier): Table = {
     try {
-      SparkTable(admin.getTableInfo(toTablePath(ident)).get())
+      val tablePath = toTablePath(ident)
+      SparkTable(tablePath, admin.getTableInfo(tablePath).get(), flussConfig)
     } catch {
       case e: ExecutionException if e.getCause.isInstanceOf[TableNotExistException] =>
         throw new NoSuchTableException(ident)

fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkTable.scala

Lines changed: 19 additions & 4 deletions
@@ -17,9 +17,24 @@
 
 package org.apache.fluss.spark
 
-import org.apache.fluss.metadata.TableInfo
+import org.apache.fluss.config.{Configuration => FlussConfiguration}
+import org.apache.fluss.metadata.{TableInfo, TablePath}
 import org.apache.fluss.spark.catalog.{AbstractSparkTable, SupportsFlussPartitionManagement}
+import org.apache.fluss.spark.write.{FlussAppendWriteBuilder, FlussUpsertWriteBuilder}
 
-case class SparkTable(table: TableInfo)
-  extends AbstractSparkTable(table)
-  with SupportsFlussPartitionManagement {}
+import org.apache.spark.sql.connector.catalog.SupportsWrite
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
+
+case class SparkTable(tablePath: TablePath, tableInfo: TableInfo, flussConfig: FlussConfiguration)
+  extends AbstractSparkTable(tableInfo)
+  with SupportsFlussPartitionManagement
+  with SupportsWrite {
+
+  override def newWriteBuilder(logicalWriteInfo: LogicalWriteInfo): WriteBuilder = {
+    if (tableInfo.getPrimaryKeys.isEmpty) {
+      new FlussAppendWriteBuilder(tablePath, logicalWriteInfo.schema(), flussConfig)
+    } else {
+      new FlussUpsertWriteBuilder(tablePath, logicalWriteInfo.schema(), flussConfig)
+    }
+  }
+}
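
For context, a minimal sketch of how this write path would be exercised from Spark SQL (not part of this commit; the catalog name "fluss", the database/table names, the option key, and the port are illustrative assumptions): a table without primary keys resolves to FlussAppendWriteBuilder, while a primary-key table resolves to FlussUpsertWriteBuilder.

import org.apache.spark.sql.SparkSession

// Hypothetical session wiring; catalog name and option key are assumptions, not from this commit.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.catalog.fluss", "org.apache.fluss.spark.SparkCatalog")
  .config("spark.sql.catalog.fluss.bootstrap.servers", "localhost:9123") // assumed option key and port
  .getOrCreate()

// Log table (no primary key): SparkTable.newWriteBuilder returns FlussAppendWriteBuilder.
spark.sql("INSERT INTO fluss.fluss_db.access_log VALUES (1, 'GET /index.html')")

// Primary-key table: SparkTable.newWriteBuilder returns FlussUpsertWriteBuilder,
// so rows with an existing key are upserted rather than appended.
spark.sql("INSERT INTO fluss.fluss_db.user_profile VALUES (1, 'Alice')")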

fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/catalog/AbstractSparkTable.scala

Lines changed: 5 additions & 2 deletions
@@ -17,7 +17,8 @@
 
 package org.apache.fluss.spark.catalog
 
-import org.apache.fluss.metadata.TableInfo
+import org.apache.fluss.config.{Configuration => FlussConfiguration}
+import org.apache.fluss.metadata.{TableInfo, TablePath}
 import org.apache.fluss.spark.SparkConversions
 
 import org.apache.spark.sql.connector.catalog.{Table, TableCapability}
@@ -39,5 +40,7 @@ abstract class AbstractSparkTable(tableInfo: TableInfo) extends Table {
 
   override def schema(): StructType = _schema
 
-  override def capabilities(): util.Set[TableCapability] = Set.empty[TableCapability].asJava
+  override def capabilities(): util.Set[TableCapability] = {
+    Set(TableCapability.BATCH_WRITE).asJava
+  }
 }

fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/catalog/WithFlussAdmin.scala

Lines changed: 8 additions & 1 deletion
@@ -33,6 +33,7 @@ trait WithFlussAdmin extends AutoCloseable {
 
   private var _connection: Connection = _
   private var _admin: Admin = _
+  private var _flussConfig: FlussConfiguration = _
 
   // TODO: init lake spark catalog
   protected var lakeCatalog: CatalogPlugin = _
@@ -43,10 +44,16 @@ trait WithFlussAdmin extends AutoCloseable {
       entry: util.Map.Entry[String, String] => flussConfigs.put(entry.getKey, entry.getValue)
     }
 
-    _connection = ConnectionFactory.createConnection(FlussConfiguration.fromMap(flussConfigs))
+    _flussConfig = FlussConfiguration.fromMap(flussConfigs)
+    _connection = ConnectionFactory.createConnection(_flussConfig)
    _admin = _connection.getAdmin
  }
 
+  protected def flussConfig: FlussConfiguration = {
+    Preconditions.checkNotNull(_flussConfig, "Fluss Configuration is not initialized.")
+    _flussConfig
+  }
+
  protected def admin: Admin = {
    Preconditions.checkNotNull(_admin, "Fluss Admin is not initialized.")
    _admin
New file: SparkAsFlussArray.scala (package org.apache.fluss.spark.row)

Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.fluss.spark.row

import org.apache.fluss.row.{BinaryString, Decimal, InternalArray => FlussInternalArray, InternalRow => FlussInternalRow, TimestampLtz, TimestampNtz}

import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData, SparkDateTimeUtils}
import org.apache.spark.sql.types.{ArrayType => SparkArrayType, DataType => SparkDataType, StructType}

/** Wraps a Spark [[SparkArrayData]] as a Fluss [[FlussInternalArray]]. */
class SparkAsFlussArray(arrayData: SparkArrayData, elementType: SparkDataType)
  extends FlussInternalArray {

  /** Returns the number of elements in this array. */
  override def size(): Int = arrayData.numElements()

  override def toBooleanArray: Array[Boolean] = arrayData.toBooleanArray()

  override def toByteArray: Array[Byte] = arrayData.toByteArray()

  override def toShortArray: Array[Short] = arrayData.toShortArray()

  override def toIntArray: Array[Int] = arrayData.toIntArray()

  override def toLongArray: Array[Long] = arrayData.toLongArray()

  override def toFloatArray: Array[Float] = arrayData.toFloatArray()

  override def toDoubleArray: Array[Double] = arrayData.toDoubleArray()

  /** Returns true if the element is null at the given position. */
  override def isNullAt(pos: Int): Boolean = arrayData.isNullAt(pos)

  /** Returns the boolean value at the given position. */
  override def getBoolean(pos: Int): Boolean = arrayData.getBoolean(pos)

  /** Returns the byte value at the given position. */
  override def getByte(pos: Int): Byte = arrayData.getByte(pos)

  /** Returns the short value at the given position. */
  override def getShort(pos: Int): Short = arrayData.getShort(pos)

  /** Returns the integer value at the given position. */
  override def getInt(pos: Int): Int = arrayData.getInt(pos)

  /** Returns the long value at the given position. */
  override def getLong(pos: Int): Long = arrayData.getLong(pos)

  /** Returns the float value at the given position. */
  override def getFloat(pos: Int): Float = arrayData.getFloat(pos)

  /** Returns the double value at the given position. */
  override def getDouble(pos: Int): Double = arrayData.getDouble(pos)

  /** Returns the string value at the given position with fixed length. */
  override def getChar(pos: Int, length: Int): BinaryString =
    BinaryString.fromBytes(arrayData.getUTF8String(pos).getBytes)

  /** Returns the string value at the given position. */
  override def getString(pos: Int): BinaryString =
    BinaryString.fromBytes(arrayData.getUTF8String(pos).getBytes)

  /**
   * Returns the decimal value at the given position.
   *
   * <p>The precision and scale are required to determine whether the decimal value was stored in a
   * compact representation (see {@link Decimal}).
   */
  override def getDecimal(pos: Int, precision: Int, scale: Int): Decimal = {
    val sparkDecimal = arrayData.getDecimal(pos, precision, scale)
    if (sparkDecimal.precision <= org.apache.spark.sql.types.Decimal.MAX_LONG_DIGITS)
      Decimal.fromUnscaledLong(
        sparkDecimal.toUnscaledLong,
        sparkDecimal.precision,
        sparkDecimal.scale)
    else
      Decimal.fromBigDecimal(
        sparkDecimal.toJavaBigDecimal,
        sparkDecimal.precision,
        sparkDecimal.scale)
  }

  /**
   * Returns the timestamp value at the given position.
   *
   * <p>The precision is required to determine whether the timestamp value was stored in a compact
   * representation (see {@link TimestampNtz}).
   */
  override def getTimestampNtz(pos: Int, precision: Int): TimestampNtz =
    TimestampNtz.fromMillis(SparkDateTimeUtils.microsToMillis(arrayData.getLong(pos)))

  /**
   * Returns the timestamp value at the given position.
   *
   * <p>The precision is required to determine whether the timestamp value was stored in a compact
   * representation (see {@link TimestampLtz}).
   */
  override def getTimestampLtz(pos: Int, precision: Int): TimestampLtz =
    TimestampLtz.fromEpochMicros(arrayData.getLong(pos))

  /** Returns the binary value at the given position with fixed length. */
  override def getBinary(pos: Int, length: Int): Array[Byte] = arrayData.getBinary(pos)

  /** Returns the binary value at the given position. */
  override def getBytes(pos: Int): Array[Byte] = arrayData.getBinary(pos)

  /** Returns the array value at the given position. */
  override def getArray(pos: Int) = new SparkAsFlussArray(
    arrayData.getArray(pos),
    elementType.asInstanceOf[SparkArrayType].elementType)

  /** Returns the row value at the given position. */
  override def getRow(pos: Int, numFields: Int): FlussInternalRow =
    new SparkAsFlussRow(elementType.asInstanceOf[StructType])
      .replace(arrayData.getStruct(pos, numFields))
}
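
A small usage sketch (not from this commit; the values and element type are illustrative) showing how the wrapper exposes Spark array data through Fluss's InternalArray interface:

import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.IntegerType

// Build a Spark ArrayData from a plain Scala array (illustrative values).
val sparkArray = ArrayData.toArrayData(Array(1, 2, 3))

// Wrap it; the element type is only consulted when reading nested arrays or rows.
val flussArray = new SparkAsFlussArray(sparkArray, IntegerType)

flussArray.size()      // 3
flussArray.getInt(2)   // 3
flussArray.isNullAt(0) // false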
New file: SparkAsFlussRow.scala (package org.apache.fluss.spark.row)

Lines changed: 129 additions & 0 deletions

@@ -0,0 +1,129 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.fluss.spark.row

import org.apache.fluss.row.{BinaryString, Decimal, InternalRow => FlussInternalRow, TimestampLtz, TimestampNtz}

import org.apache.spark.sql.catalyst.{InternalRow => SparkInternalRow}
import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils
import org.apache.spark.sql.types.StructType

/** Wraps a Spark [[SparkInternalRow]] as a Fluss [[FlussInternalRow]]. */
class SparkAsFlussRow(schema: StructType) extends FlussInternalRow with Serializable {

  val fieldCount: Int = schema.length

  var row: SparkInternalRow = _

  def replace(row: SparkInternalRow): SparkAsFlussRow = {
    this.row = row
    this
  }

  /**
   * Returns the number of fields in this row.
   *
   * <p>The number does not include {@link ChangeType}. It is kept separately.
   */
  override def getFieldCount: Int = fieldCount

  /** Returns true if the element is null at the given position. */
  override def isNullAt(pos: Int): Boolean = row.isNullAt(pos)

  /** Returns the boolean value at the given position. */
  override def getBoolean(pos: Int): Boolean = row.getBoolean(pos)

  /** Returns the byte value at the given position. */
  override def getByte(pos: Int): Byte = row.getByte(pos)

  /** Returns the short value at the given position. */
  override def getShort(pos: Int): Short = row.getShort(pos)

  /** Returns the integer value at the given position. */
  override def getInt(pos: Int): Int = row.getInt(pos)

  /** Returns the long value at the given position. */
  override def getLong(pos: Int): Long = row.getLong(pos)

  /** Returns the float value at the given position. */
  override def getFloat(pos: Int): Float = row.getFloat(pos)

  /** Returns the double value at the given position. */
  override def getDouble(pos: Int): Double = row.getDouble(pos)

  /** Returns the string value at the given position with fixed length. */
  override def getChar(pos: Int, length: Int): BinaryString =
    BinaryString.fromString(row.getUTF8String(pos).toString)

  /** Returns the string value at the given position. */
  override def getString(pos: Int): BinaryString = BinaryString.fromString(row.getString(pos))

  /**
   * Returns the decimal value at the given position.
   *
   * <p>The precision and scale are required to determine whether the decimal value was stored in a
   * compact representation (see {@link Decimal}).
   */
  override def getDecimal(pos: Int, precision: Int, scale: Int): Decimal = {
    val sparkDecimal = row.getDecimal(pos, precision, scale)
    if (sparkDecimal.precision <= org.apache.spark.sql.types.Decimal.MAX_LONG_DIGITS)
      Decimal.fromUnscaledLong(
        sparkDecimal.toUnscaledLong,
        sparkDecimal.precision,
        sparkDecimal.scale)
    else
      Decimal.fromBigDecimal(
        sparkDecimal.toJavaBigDecimal,
        sparkDecimal.precision,
        sparkDecimal.scale)
  }

  /**
   * Returns the timestamp value at the given position.
   *
   * <p>The precision is required to determine whether the timestamp value was stored in a compact
   * representation (see {@link TimestampNtz}).
   */
  override def getTimestampNtz(pos: Int, precision: Int): TimestampNtz =
    TimestampNtz.fromMillis(SparkDateTimeUtils.microsToMillis(row.getLong(pos)))

  /**
   * Returns the timestamp value at the given position.
   *
   * <p>The precision is required to determine whether the timestamp value was stored in a compact
   * representation (see {@link TimestampLtz}).
   */
  override def getTimestampLtz(pos: Int, precision: Int): TimestampLtz =
    TimestampLtz.fromEpochMicros(row.getLong(pos))

  /** Returns the binary value at the given position with fixed length. */
  override def getBinary(pos: Int, length: Int): Array[Byte] = row.getBinary(pos)

  /** Returns the binary value at the given position. */
  override def getBytes(pos: Int): Array[Byte] = row.getBinary(pos)

  /** Returns the array value at the given position. */
  override def getArray(pos: Int) =
    new SparkAsFlussArray(row.getArray(pos), schema.fields(pos).dataType)

  /** Returns the row value at the given position. */
  override def getRow(pos: Int, numFields: Int): FlussInternalRow =
    new SparkAsFlussRow(schema.fields(pos).dataType.asInstanceOf[StructType])
      .replace(row.getStruct(pos, numFields))

}
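
A matching sketch for the row wrapper (illustrative only, not commit code; the schema and values are made up): the intended pattern is to allocate one SparkAsFlussRow per writer and call replace() for each incoming Spark row, avoiding a per-record allocation.

import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical schema and row, for illustration only.
val schema = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))
val sparkRow = new GenericInternalRow(Array[Any](42L, UTF8String.fromString("fluss")))

// One reusable wrapper; replace() swaps in each new Spark row.
val flussRow = new SparkAsFlussRow(schema).replace(sparkRow)

flussRow.getFieldCount // 2
flussRow.getLong(0)    // 42
flussRow.getString(1)  // BinaryString "fluss"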
