Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public static InternalArray copyArray(InternalArray from, DataType eleType) {
return new GenericArray(newArray);
}

private static InternalMap copyMap(InternalMap map, DataType keyType, DataType valueType) {
public static InternalMap copyMap(InternalMap map, DataType keyType, DataType valueType) {
if (map instanceof BinaryMap) {
return ((BinaryMap) map).copy();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ object DataConverter {
}

def toSparkMap(flussMap: FlussInternalMap, mapType: FlussMapType): SparkMapData = {
// TODO: support map type in fluss-spark
throw new UnsupportedOperationException()
new FlussAsSparkMap(mapType).replace(flussMap)
}

def toSparkInternalRow(flussRow: FlussInternalRow, rowType: RowType): SparkInteralRow = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData, MapData
import org.apache.spark.sql.types.{DataType => SparkDataType, Decimal => SparkDecimal}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

/** Wraps a Fluss [[FlussInternalArray]] as a Spark [[SparkArrayData]]. */
class FlussAsSparkArray(elementType: FlussDataType) extends SparkArrayData {

var flussArray: FlussInternalArray = _
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.fluss.spark.row

import org.apache.fluss.row.{InternalMap => FlussInternalMap}
import org.apache.fluss.types.{MapType => FlussMapType}
import org.apache.fluss.utils.InternalRowUtils

import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData, MapData => SparkMapData}

/** Adapts a Fluss [[FlussInternalMap]] so it can be read through Spark's [[SparkMapData]] API. */
class FlussAsSparkMap(mapType: FlussMapType) extends SparkMapData {

  // Backing Fluss map; assigned via replace() so a single wrapper can be reused across records.
  var flussMap: FlussInternalMap = _

  /** Points this wrapper at `map` and returns itself, enabling fluent reuse. */
  def replace(map: FlussInternalMap): SparkMapData = {
    flussMap = map
    this
  }

  /** Number of key-value entries in the wrapped map. */
  override def numElements(): Int = flussMap.size()

  /** Deep-copies the underlying Fluss map and returns a wrapper around the copy. */
  override def copy(): SparkMapData = {
    val copied = InternalRowUtils.copyMap(flussMap, mapType.getKeyType, mapType.getValueType)
    new FlussAsSparkMap(mapType).replace(copied)
  }

  /** Keys of the map, exposed as Spark array data (index-aligned with [[valueArray]]). */
  override def keyArray(): SparkArrayData =
    new FlussAsSparkArray(mapType.getKeyType).replace(flussMap.keyArray())

  /** Values of the map, exposed as Spark array data (index-aligned with [[keyArray]]). */
  override def valueArray(): SparkArrayData =
    new FlussAsSparkArray(mapType.getValueType).replace(flussMap.valueArray())
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.fluss.spark.row

import org.apache.fluss.row.{InternalRow => FlussInternalRow}
import org.apache.fluss.types.{ArrayType => FlussArrayType, BinaryType => FlussBinaryType, LocalZonedTimestampType, RowType, TimestampType}
import org.apache.fluss.types.{ArrayType => FlussArrayType, BinaryType => FlussBinaryType, LocalZonedTimestampType, MapType => FlussMapType, RowType, TimestampType}
import org.apache.fluss.utils.InternalRowUtils

import org.apache.spark.sql.catalyst.{InternalRow => SparkInteralRow}
Expand All @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData, MapData
import org.apache.spark.sql.types.{DataType => SparkDataType, Decimal => SparkDecimal}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

/** Wraps a Fluss [[FlussInternalRow]] as a Spark [[SparkInteralRow]]. */
class FlussAsSparkRow(rowType: RowType) extends SparkInteralRow {

val fieldCount: Int = rowType.getFieldCount
Expand Down Expand Up @@ -104,8 +105,9 @@ class FlussAsSparkRow(rowType: RowType) extends SparkInteralRow {
}

override def getMap(ordinal: Int): SparkMapData = {
// TODO: support map type in fluss-spark
throw new UnsupportedOperationException()
val mapType = rowType.getTypeAt(ordinal).asInstanceOf[FlussMapType]
val flussMap = row.getMap(ordinal)
DataConverter.toSparkMap(flussMap, mapType)
}

override def get(ordinal: Int, dataType: SparkDataType): AnyRef = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.fluss.spark.row
import org.apache.fluss.row.{BinaryString, Decimal, InternalArray => FlussInternalArray, InternalMap, InternalRow => FlussInternalRow, TimestampLtz, TimestampNtz}

import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData}
import org.apache.spark.sql.types.{ArrayType => SparkArrayType, DataType => SparkDataType, StructType}
import org.apache.spark.sql.types.{ArrayType => SparkArrayType, DataType => SparkDataType, MapType => SparkMapType, StructType}

/** Wraps a Spark [[SparkArrayData]] as a Fluss [[FlussInternalArray]]. */
class SparkAsFlussArray(arrayData: SparkArrayData, elementType: SparkDataType)
Expand Down Expand Up @@ -125,13 +125,14 @@ class SparkAsFlussArray(arrayData: SparkArrayData, elementType: SparkDataType)
arrayData.getArray(pos),
elementType.asInstanceOf[SparkArrayType].elementType)

/** Returns the map value at the given position. */
override def getMap(pos: Int): InternalMap = {
val mapType = elementType.asInstanceOf[SparkMapType]
SparkAsFlussMap(arrayData.getMap(pos), mapType)
}

/** Returns the row value at the given position. */
override def getRow(pos: Int, numFields: Int): FlussInternalRow =
new SparkAsFlussRow(elementType.asInstanceOf[StructType])
.replace(arrayData.getStruct(pos, numFields))

/** Returns the map value at the given position. */
override def getMap(pos: Int): InternalMap = {
throw new UnsupportedOperationException()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.fluss.spark.row

import org.apache.fluss.row.{InternalArray => FlussInternalArray, InternalMap => FlussInternalMap}

import org.apache.spark.sql.catalyst.util.{MapData => SparkMapData}
import org.apache.spark.sql.types.{DataType => SparkDataType, MapType => SparkMapType}

/** Adapts a Spark [[SparkMapData]] so it can be read through Fluss's [[FlussInternalMap]] API. */
class SparkAsFlussMap(mapData: SparkMapData, keyType: SparkDataType, valueType: SparkDataType)
  extends FlussInternalMap
  with Serializable {

  /** Number of key-value mappings in the wrapped Spark map. */
  override def size(): Int = mapData.numElements()

  /**
   * Array view of the keys contained in this map.
   *
   * <p>Entry i of this array pairs with entry i of [[valueArray]].
   */
  override def keyArray(): FlussInternalArray =
    new SparkAsFlussArray(mapData.keyArray(), keyType)

  /**
   * Array view of the values contained in this map.
   *
   * <p>Entry i of this array pairs with entry i of [[keyArray]].
   */
  override def valueArray(): FlussInternalArray =
    new SparkAsFlussArray(mapData.valueArray(), valueType)
}

object SparkAsFlussMap {

  /** Convenience factory that derives the key and value types from a Spark [[SparkMapType]]. */
  def apply(mapData: SparkMapData, mapType: SparkMapType): SparkAsFlussMap = {
    val kt = mapType.keyType
    val vt = mapType.valueType
    new SparkAsFlussMap(mapData, kt, vt)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.fluss.spark.row
import org.apache.fluss.row.{BinaryString, Decimal, InternalMap, InternalRow => FlussInternalRow, TimestampLtz, TimestampNtz}

import org.apache.spark.sql.catalyst.{InternalRow => SparkInternalRow}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{MapType => SparkMapType, StructType}

/** Wraps a Spark [[SparkInternalRow]] as a Fluss [[FlussInternalRow]]. */
class SparkAsFlussRow(schema: StructType) extends FlussInternalRow with Serializable {
Expand Down Expand Up @@ -127,6 +127,8 @@ class SparkAsFlussRow(schema: StructType) extends FlussInternalRow with Serializ

/** Returns the map value at the given position. */
override def getMap(pos: Int): InternalMap = {
throw new UnsupportedOperationException()
val sparkMapData = row.getMap(pos)
val mapType = schema.fields(pos).dataType.asInstanceOf[SparkMapType]
SparkAsFlussMap(sparkMapData, mapType)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,29 +213,31 @@ class SparkLogTableReadTest extends FlussSparkTestBase {

test("Spark Read: nested data types table") {
withTable("t") {
// TODO: support map type
sql(s"""
|CREATE TABLE $DEFAULT_DATABASE.t (
|id INT,
|arr ARRAY<INT>,
|map MAP<STRING, INT>,
|struct_col STRUCT<col1: INT, col2: STRING>
|)""".stripMargin)

sql(s"""
|INSERT INTO $DEFAULT_DATABASE.t VALUES
|(1, ARRAY(1, 2, 3), STRUCT(100, 'nested_value')),
|(2, ARRAY(7, 8, 9), STRUCT(200, 'nested_value2'))
|(1, ARRAY(1, 2, 3), MAP("k1", 111, "k2", 222), STRUCT(100, 'nested_value')),
|(2, ARRAY(7, 8, 9), MAP("k1", 333, "k2", 444), STRUCT(200, 'nested_value2'))
|""".stripMargin)

checkAnswer(
sql(s"SELECT * FROM $DEFAULT_DATABASE.t ORDER BY id"),
Row(
1,
Seq(1, 2, 3),
Map("k1" -> 111, "k2" -> 222),
Row(100, "nested_value")
) :: Row(
2,
Seq(7, 8, 9),
Map("k1" -> 333, "k2" -> 444),
Row(200, "nested_value2")
) :: Nil
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ class SparkWriteTest extends FlussSparkTestBase {
| 1234567.89, 12345678900987654321.12,
| "test",
| TO_TIMESTAMP('2025-12-31 10:00:00', 'yyyy-MM-dd kk:mm:ss'),
| array(11.11F, 22.22F), struct(123L, "apache fluss")
| array(11.11F, 22.22F),
| map("k1", 111, "k2", 222),
| struct(123L, "apache fluss")
|)
|""".stripMargin)

Expand All @@ -56,7 +58,7 @@ class SparkWriteTest extends FlussSparkTestBase {
assertThat(rows.length).isEqualTo(1)

val row = rows.head
assertThat(row.getFieldCount).isEqualTo(13)
assertThat(row.getFieldCount).isEqualTo(14)
assertThat(row.getBoolean(0)).isEqualTo(true)
assertThat(row.getByte(1)).isEqualTo(1.toByte)
assertThat(row.getShort(2)).isEqualTo(10.toShort)
Expand All @@ -71,7 +73,12 @@ class SparkWriteTest extends FlussSparkTestBase {
assertThat(row.getTimestampLtz(10, 6).toInstant)
.isEqualTo(Timestamp.valueOf("2025-12-31 10:00:00.0").toInstant)
assertThat(row.getArray(11).toFloatArray).containsExactly(Array(11.11f, 22.22f): _*)
val nestedRow = row.getRow(12, 2)
val mapData = row.getMap(12)
assertThat(mapData.size()).isEqualTo(2)
assertThat(mapData.keyArray().getString(0).toString).isEqualTo("k1")
assertThat(mapData.keyArray().getString(1).toString).isEqualTo("k2")
assertThat(mapData.valueArray().toIntArray).containsExactly(Array(111, 222): _*)
val nestedRow = row.getRow(13, 2)
assertThat(nestedRow.getFieldCount).isEqualTo(2)
assertThat(nestedRow.getLong(0)).isEqualTo(123L)
assertThat(nestedRow.getString(1).toString).isEqualTo("apache fluss")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@

package org.apache.fluss.spark.row

import org.apache.fluss.row.{BinaryString, Decimal => FlussDecimal, GenericArray, GenericRow, TimestampLtz, TimestampNtz}
import org.apache.fluss.types.{ArrayType, CharType, DataTypes, DecimalType, LocalZonedTimestampType, RowType, TimestampType}
import org.apache.fluss.row.{BinaryString, Decimal => FlussDecimal, GenericArray, GenericMap, GenericRow, TimestampLtz, TimestampNtz}
import org.apache.fluss.types.{ArrayType, CharType, DataTypes, DecimalType, LocalZonedTimestampType, MapType, RowType, TimestampType}

import org.apache.spark.sql.types.{Decimal => SparkDecimal}
import org.apache.spark.unsafe.types.UTF8String
import org.assertj.core.api.Assertions.assertThat
import org.scalatest.funsuite.AnyFunSuite

import scala.collection.JavaConverters._

class DataConverterTest extends AnyFunSuite {

test("toSparkObject: null value") {
Expand Down Expand Up @@ -274,6 +276,25 @@ class DataConverterTest extends AnyFunSuite {
assertThat(result.asInstanceOf[Long]).isEqualTo(2000000000L) // microseconds
}

// Verifies that toSparkMap exposes a Fluss GenericMap's keys and values through Spark's MapData API.
test("toSparkMap: Map type") {
  val entries = Map(
    BinaryString.fromString("a") -> Integer.valueOf(1),
    BinaryString.fromString("b") -> Integer.valueOf(2))
  val flussMap = new GenericMap(entries.asJava)
  val sparkMap = DataConverter.toSparkMap(flussMap, new MapType(DataTypes.STRING, DataTypes.INT))

  assertThat(sparkMap.numElements()).isEqualTo(2)

  val keys = sparkMap.keyArray()
  assertThat(keys.numElements()).isEqualTo(2)
  assertThat(keys.getUTF8String(0).toString).isEqualTo("a")
  assertThat(keys.getUTF8String(1).toString).isEqualTo("b")

  val values = sparkMap.valueArray()
  assertThat(values.numElements()).isEqualTo(2)
  assertThat(values.getInt(0)).isEqualTo(1)
  assertThat(values.getInt(1)).isEqualTo(2)
}

test("toSparkObject: ROW type") {
val rowType = RowType
.builder()
Expand All @@ -288,10 +309,4 @@ class DataConverterTest extends AnyFunSuite {
assertThat(result).isNotNull()
assertThat(result.asInstanceOf[FlussAsSparkRow].getInt(0)).isEqualTo(42)
}

test("toSparkMap: unsupported") {
assertThrows[UnsupportedOperationException] {
DataConverter.toSparkMap(null, null)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,17 @@ class FlussAsSparkArrayTest extends AnyFunSuite {
assertThat(sparkInnerArray2.getInt(2)).isEqualTo(6)
}

test("getMap: unsupported operation") {
test("getMap: read map array") {
val mapType = DataTypes.MAP(DataTypes.INT, DataTypes.STRING)
val flussArray = GenericArray.of(new GenericMap(Map(1 -> "map").asJava))
val innerMap =
new GenericMap(Map(Integer.valueOf(1) -> BinaryString.fromString("value1")).asJava)
val flussArray = new GenericArray(Array[Object](innerMap))
val sparkArray = new FlussAsSparkArray(mapType).replace(flussArray)

assertThrows[UnsupportedOperationException] {
sparkArray.getMap(0)
}
val sparkMap = sparkArray.getMap(0)
assertThat(sparkMap.numElements()).isEqualTo(1)
assertThat(sparkMap.keyArray().getInt(0)).isEqualTo(1)
assertThat(sparkMap.valueArray().getUTF8String(0).toString).isEqualTo("value1")
}

test("getInterval: unsupported operation") {
Expand Down
Loading