Commit efca890

Fixed writing logic and conversion of complex types when reading
1 parent ce06954 commit efca890

9 files changed: 81 additions & 48 deletions

src/main/scala/com/audienceproject/spark/dynamodb/catalyst/JavaConverter.scala

Lines changed: 4 additions & 5 deletions
@@ -15,7 +15,7 @@ object JavaConverter {
         elementType match {
             case ArrayType(innerType, _) => extractArray(row.getArray(index), innerType)
             case MapType(keyType, valueType, _) => extractMap(row.getMap(index), keyType, valueType)
-            case StructType(fields) => mapStruct(row.getStruct(index, fields.length), fields)
+            case StructType(fields) => extractStruct(row.getStruct(index, fields.length), fields)
             case StringType => row.getString(index)
             case _ => row.get(index, elementType)
         }
@@ -25,7 +25,7 @@
         elementType match {
             case ArrayType(innerType, _) => array.toSeq[ArrayData](elementType).map(extractArray(_, innerType)).asJava
             case MapType(keyType, valueType, _) => array.toSeq[MapData](elementType).map(extractMap(_, keyType, valueType)).asJava
-            case structType: StructType => array.toSeq[InternalRow](structType).map(mapStruct(_, structType.fields)).asJava
+            case structType: StructType => array.toSeq[InternalRow](structType).map(extractStruct(_, structType.fields)).asJava
             case StringType => convertStringArray(array).asJava
             case _ => array.toSeq[Any](elementType).asJava
         }
@@ -38,21 +38,20 @@
         val values = valueType match {
             case ArrayType(innerType, _) => map.valueArray().toSeq[ArrayData](valueType).map(extractArray(_, innerType))
             case MapType(innerKeyType, innerValueType, _) => map.valueArray().toSeq[MapData](valueType).map(extractMap(_, innerKeyType, innerValueType))
-            case structType: StructType => map.valueArray().toSeq[InternalRow](structType).map(mapStruct(_, structType.fields))
+            case structType: StructType => map.valueArray().toSeq[InternalRow](structType).map(extractStruct(_, structType.fields))
             case StringType => convertStringArray(map.valueArray())
             case _ => map.valueArray().toSeq[Any](valueType)
         }
         val kvPairs = for (i <- 0 until map.numElements()) yield keys(i) -> values(i)
         Map(kvPairs: _*).asJava
     }

-    def mapStruct(row: InternalRow, fields: Seq[StructField]): util.Map[String, Any] = {
+    def extractStruct(row: InternalRow, fields: Seq[StructField]): util.Map[String, Any] = {
         val kvPairs = for (i <- 0 until row.numFields)
             yield fields(i).name -> extractRowValue(row, i, fields(i).dataType)
         Map(kvPairs: _*).asJava
     }

-
     def convertStringArray(array: ArrayData): Seq[String] =
         array.toSeq[UTF8String](StringType).map(_.toString)
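For orientation, a minimal sketch (not part of the commit) of what the renamed extractStruct produces on the write path: a Catalyst struct row is flattened into a java.util.Map that the DynamoDB Document API accepts. The field names and values below are made up.

import com.audienceproject.spark.dynamodb.catalyst.JavaConverter
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StringType, StructField}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical two-field struct; extractStruct walks the fields and converts
// each value via extractRowValue, yielding a java.util.Map keyed by field name.
val fields = Seq(StructField("name", StringType), StructField("quantity", IntegerType))
val row = InternalRow(UTF8String.fromString("apple"), 3)
val itemMap = JavaConverter.extractStruct(row, fields) // {name=apple, quantity=3}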

src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala

Lines changed: 5 additions & 5 deletions
@@ -63,7 +63,7 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para

         // Partitioning calculation.
         val numPartitions = parameters.get("readpartitions").map(_.toInt).getOrElse(
-            (tableSize / maxPartitionBytes).toInt
+            (tableSize / maxPartitionBytes).toInt max 1
         )

         // Provisioned or on-demand throughput.
@@ -133,11 +133,11 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
         keySchema match {
             case KeySchema(hashKey, None) =>
                 val hashKeyType = schema(hashKey).dataType
-                item.withPrimaryKey(hashKey, row.get(hashKeyIndex, hashKeyType))
+                item.withPrimaryKey(hashKey, JavaConverter.extractRowValue(row, hashKeyIndex, hashKeyType))
             case KeySchema(hashKey, Some(rangeKey)) =>
-                val hashKeyType = schema(hashKey).dataType
-                val rangeKeyType = schema(rangeKey).dataType
-                item.withPrimaryKey(hashKey, row.get(hashKeyIndex, hashKeyType), rangeKey, row.get(rangeKeyIndex.get, rangeKeyType))
+                val hashKeyValue = JavaConverter.extractRowValue(row, hashKeyIndex, schema(hashKey).dataType)
+                val rangeKeyValue = JavaConverter.extractRowValue(row, rangeKeyIndex.get, schema(rangeKey).dataType)
+                item.withPrimaryKey(hashKey, hashKeyValue, rangeKey, rangeKeyValue)
         }

         // Map remaining columns.
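The new max 1 guard matters for small tables: with integer division, a table smaller than one partition's worth of bytes previously produced zero read partitions. A quick sketch of the arithmetic with assumed sizes (the real maxPartitionBytes comes from the connector's configuration):

// Assumed values, for illustration only.
val maxPartitionBytes = 128L * 1024 * 1024 // e.g. 128 MB per partition
val tableSize = 5L * 1024 * 1024           // a 5 MB table

val before = (tableSize / maxPartitionBytes).toInt       // 0, so nothing would be scanned
val after  = (tableSize / maxPartitionBytes).toInt max 1 // at least 1 partition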

src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String

         // Partitioning calculation.
         val numPartitions = parameters.get("readpartitions").map(_.toInt).getOrElse(
-            (indexSize / maxPartitionBytes).toInt
+            (indexSize / maxPartitionBytes).toInt max 1
         )

         // Provisioned or on-demand throughput.

src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoBatchWriter.scala

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+package com.audienceproject.spark.dynamodb.datasource
+
+import com.audienceproject.spark.dynamodb.connector.DynamoWritable
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.types.StructType
+
+import scala.collection.mutable.ArrayBuffer
+
+class DynamoBatchWriter(batchSize: Int,
+                        connector: DynamoWritable,
+                        schema: StructType)
+    extends DataWriter[InternalRow] {
+
+    private val buffer = new ArrayBuffer[InternalRow](batchSize)
+
+    override def write(record: InternalRow): Unit = {
+        buffer += record.copy()
+        if (buffer.size == batchSize) {
+            flush()
+        }
+    }
+
+    override def commit(): WriterCommitMessage = {
+        flush()
+        new WriterCommitMessage {}
+    }
+
+    override def abort(): Unit = {}
+
+    private def flush(): Unit = {
+        connector.putItems(schema, buffer)
+        buffer.clear()
+    }
+
+}
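A hedged sketch of how the new batch writer is driven; in practice Spark creates it through the DynamoWriterFactory added further down in this commit, and the connector, schema and rows below are placeholders. Note that write buffers record.copy() rather than the record itself, presumably because Spark may reuse the same InternalRow instance across calls.

import com.audienceproject.spark.dynamodb.connector.DynamoWritable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType

// Placeholders standing in for what Spark supplies at runtime.
def dynamoConnector: DynamoWritable = ???
def schema: StructType = ???
def partitionRows: Iterator[InternalRow] = ???

val writer = new DynamoBatchWriter(25, dynamoConnector, schema)
partitionRows.foreach(writer.write) // copies each row and flushes every 25 records
writer.commit()                     // flushes any records still buffered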

src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoDataSourceWriter.scala

Lines changed: 5 additions & 27 deletions
@@ -22,44 +22,22 @@ package com.audienceproject.spark.dynamodb.datasource

 import com.audienceproject.spark.dynamodb.connector.TableConnector
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage}
 import org.apache.spark.sql.types.StructType

-import scala.collection.mutable.ArrayBuffer
-
 class DynamoDataSourceWriter(parallelism: Int, parameters: Map[String, String], schema: StructType)
     extends DataSourceWriter {

-    private val tableName = parameters("tableName")
-    private val batchSize = parameters.getOrElse("writeBatchSize", "25").toInt
+    private val tableName = parameters("tablename")
+    private val batchSize = parameters.getOrElse("writebatchsize", "25").toInt

     private val dynamoConnector = new TableConnector(tableName, parallelism, parameters)

-    override def createWriterFactory(): DataWriterFactory[InternalRow] = new DataWriterFactory[InternalRow] {
-        override def createDataWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] =
-            new DynamoDataWriter
-    }
+    override def createWriterFactory(): DataWriterFactory[InternalRow] =
+        new DynamoWriterFactory(batchSize, dynamoConnector, schema)

     override def commit(messages: Array[WriterCommitMessage]): Unit = {}

     override def abort(messages: Array[WriterCommitMessage]): Unit = {}

-    private class DynamoDataWriter extends DataWriter[InternalRow] {
-
-        private val buffer = new ArrayBuffer[InternalRow](batchSize)
-
-        override def write(record: InternalRow): Unit = {
-            buffer += record
-            if (buffer.size == batchSize) {
-                dynamoConnector.putItems(schema, buffer)
-                buffer.clear()
-            }
-        }
-
-        override def commit(): WriterCommitMessage = new WriterCommitMessage {}
-
-        override def abort(): Unit = {}
-
-    }
-
 }
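The option keys looked up here change from camelCase to lower case ("tablename", "writebatchsize"), presumably because Spark's DataSourceV2 options map is case-insensitive and hands keys to the source in lower case, so the lookup has to use that form. Callers can keep passing the camelCase names; a minimal, assumed usage sketch with a placeholder DataFrame and table name:

import org.apache.spark.sql.DataFrame

def someDataFrame: DataFrame = ??? // placeholder

someDataFrame.write
    .format("com.audienceproject.spark.dynamodb.datasource")
    .option("tableName", "SomeTable")  // reaches the writer as "tablename"
    .option("writeBatchSize", "25")    // reaches the writer as "writebatchsize"
    .save()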

src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoWriterFactory.scala

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+package com.audienceproject.spark.dynamodb.datasource
+
+import com.audienceproject.spark.dynamodb.connector.DynamoWritable
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory}
+import org.apache.spark.sql.types.StructType
+
+class DynamoWriterFactory(batchSize: Int,
+                          connector: DynamoWritable,
+                          schema: StructType)
+    extends DataWriterFactory[InternalRow] {
+
+    override def createDataWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] =
+        new DynamoBatchWriter(batchSize, connector, schema)
+
+}

src/main/scala/com/audienceproject/spark/dynamodb/datasource/TypeConversion.scala

Lines changed: 9 additions & 6 deletions
@@ -21,7 +21,8 @@
 package com.audienceproject.spark.dynamodb.datasource

 import com.amazonaws.services.dynamodbv2.document.Item
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String

@@ -51,6 +52,8 @@ private[dynamodb] object TypeConversion {
         case _ => throw new IllegalArgumentException(s"Spark DataType '${sparkType.typeName}' could not be mapped to a corresponding DynamoDB data type.")
     }

+    private val stringConverter = (value: Any) => UTF8String.fromString(value.asInstanceOf[String])
+
     private def convertValue(sparkType: DataType): Any => Any =

         sparkType match {
@@ -71,7 +74,7 @@
             case _ => null
         }
         case StringType => {
-            case string: String => string
+            case string: String => UTF8String.fromString(string)
             case _ => null
         }
         case BinaryType => {
@@ -94,18 +97,18 @@
     }

     private def extractArray(converter: Any => Any): Any => Any = {
-        case list: java.util.List[_] => list.asScala.map(converter)
-        case set: java.util.Set[_] => set.asScala.map(converter).toSeq
+        case list: java.util.List[_] => new GenericArrayData(list.asScala.map(converter))
+        case set: java.util.Set[_] => new GenericArrayData(set.asScala.map(converter).toSeq)
         case _ => null
     }

     private def extractMap(converter: Any => Any): Any => Any = {
-        case map: java.util.Map[_, _] => map.asScala.mapValues(converter)
+        case map: java.util.Map[_, _] => ArrayBasedMapData(map, stringConverter, converter)
         case _ => null
     }

     private def extractStruct(conversions: Seq[(String, Any => Any)]): Any => Any = {
-        case map: java.util.Map[_, _] => Row.fromSeq(conversions.map({
+        case map: java.util.Map[_, _] => InternalRow.fromSeq(conversions.map({
             case (name, conv) => conv(map.get(name))
         }))
         case _ => null
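The common thread in these changes is that the read path now produces Catalyst-internal representations (UTF8String, GenericArrayData, ArrayBasedMapData, InternalRow) instead of external Java/Scala objects, which is what an InternalRow-based DataSourceV2 reader must emit. A small sketch using the same calls as the diff, with made-up values:

import scala.collection.JavaConverters._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.unsafe.types.UTF8String

val toUtf8 = (value: Any) => UTF8String.fromString(value.asInstanceOf[String])

// A DynamoDB list or map attribute arrives as a java.util collection...
val javaList: java.util.List[String] = List("a", "b").asJava
val javaMap: java.util.Map[String, String] = Map("color" -> "red").asJava

// ...and is wrapped into Catalyst's ArrayData / MapData so it can sit inside an InternalRow.
val arrayData = new GenericArrayData(javaList.asScala.map(toUtf8))
val mapData = ArrayBasedMapData(javaMap, toUtf8, toUtf8)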

src/main/scala/com/audienceproject/spark/dynamodb/implicits.scala

Lines changed: 2 additions & 2 deletions
@@ -53,7 +53,7 @@ object implicits {
         }

         private def getDynamoDBSource(tableName: String): DataFrameReader =
-            reader.format("com.audienceproject.spark.dynamodb").option("tableName", tableName)
+            reader.format("com.audienceproject.spark.dynamodb.datasource").option("tableName", tableName)

         private def getColumnsAlias(dataFrame: DataFrame): DataFrame = {
             val columnsAlias = dataFrame.schema.collect({
@@ -70,7 +70,7 @@
     implicit class DynamoDBDataFrameWriter[T](writer: DataFrameWriter[T]) {

         def dynamodb(tableName: String): Unit =
-            writer.format("com.audienceproject.spark.dynamodb").option("tableName", tableName).save()
+            writer.format("com.audienceproject.spark.dynamodb.datasource").option("tableName", tableName).save()

     }

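With the format string now pointing at the datasource package, the implicits are used as before. A short usage sketch (the SparkSession setup and table names are placeholders):

import org.apache.spark.sql.SparkSession
import com.audienceproject.spark.dynamodb.implicits._

val spark = SparkSession.builder.getOrCreate()

val fruit = spark.read.dynamodb("SomeTable") // reader implicit from this file
fruit.show()
fruit.write.dynamodb("SomeOtherTable")       // writer implicit updated in this diff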

src/test/scala/com/audienceproject/spark/dynamodb/DefaultSourceTest.scala

Lines changed: 3 additions & 2 deletions
@@ -29,8 +29,9 @@ import scala.collection.JavaConverters._
 class DefaultSourceTest extends AbstractInMemoryTest {

     test("Table count is 9") {
-        val count = spark.read.dynamodb("TestFruit").count()
-        assert(count === 9)
+        val count = spark.read.dynamodb("TestFruit")
+        count.show()
+        assert(count.count() === 9)
     }

     test("Column sum is 27") {
