Skip to content
This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit 8a5e5c9

Browse files
committed
Cleaned up updateItems function and changed write relation to always use defaultParallelism for numPartitions
1 parent 000ddf4 commit 8a5e5c9

File tree

9 files changed

+97
-102
lines changed

9 files changed

+97
-102
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ The following parameters can be set as options on the Spark reader object before
6161

6262
The following parameters can be set as options on the Spark writer object before saving.
6363

64-
- `writePartitions` number of partitions to split the given DataFrame into when writing to DynamoDB. Set to `skip` to avoid repartitioning the DataFrame before writing. Defaults to `sparkContext.defaultParallelism`
6564
- `writeBatchSize` number of items to send per call to DynamoDB BatchWriteItem. Default 25.
66-
- `update` if true writes will be using UpdateItem on keys rather than BatchWriteItem. Default false
65+
- `targetCapacity` fraction of provisioned write capacity on the table to consume for writing or updating. Default 1 (i.e. 100% capacity).
66+
- `update` if true, items will be written using UpdateItem on keys rather than BatchWriteItem. Default false.
6767

6868
## Running Unit Tests
6969
The unit tests depend on the AWS DynamoDBLocal client, which in turn depends on [sqlite4java](https://bitbucket.org/almworks/sqlite4java/src/master/). I had some problems running this on macOS (OS X), so I had to put the library directly in the /lib folder, as graciously explained in [this Stack Overflow answer](https://stackoverflow.com/questions/34137043/amazon-dynamodb-local-unknown-error-exception-or-failure/35353377#35353377).

build.sbt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ organization := "com.audienceproject"
22

33
name := "spark-dynamodb"
44

5-
version := "0.3.6"
5+
version := "0.4.0"
66

77
description := "Plug-and-play implementation of an Apache Spark custom data source for AWS DynamoDB."
88

@@ -13,7 +13,7 @@ crossScalaVersions := Seq("2.11.12", "2.12.7")
1313
resolvers += "DynamoDBLocal" at "https://s3-us-west-2.amazonaws.com/dynamodb-local/release"
1414

1515
libraryDependencies += "com.amazonaws" % "aws-java-sdk-dynamodb" % "1.11.466"
16-
libraryDependencies += "com.amazonaws" % "DynamoDBLocal" % "[1.11,2.0)" % "test"
16+
libraryDependencies += "com.amazonaws" % "DynamoDBLocal" % "[1.11,2.0)" % "test" exclude("com.google.guava", "guava")
1717

1818
libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.0" % "provided"
1919
libraryDependencies += "com.google.guava" % "guava" % "14.0.1" % "provided"

project/plugins.sbt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ logLevel := Level.Warn
22

33
addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.1.1")
44
addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.4")
5+
addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2")

src/main/scala/com/audienceproject/spark/dynamodb/DefaultSource.scala

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ import org.apache.spark.sql.sources._
3131
import org.apache.spark.sql.types.StructType
3232
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
3333

34-
class DefaultSource extends RelationProvider
35-
with SchemaRelationProvider
36-
with CreatableRelationProvider {
34+
class DefaultSource extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {
3735

3836
val logger: Logger = LoggerFactory.getLogger(this.getClass.getName)
3937

@@ -50,19 +48,15 @@ class DefaultSource extends RelationProvider
5048
override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String],
5149
data: DataFrame): BaseRelation = {
5250
logger.info(s"Using Guava version $getGuavaVersion")
53-
val writeData =
54-
if (parameters.get("writePartitions").contains("skip")) data
55-
else data.repartition(parameters.get("writePartitions").map(_.toInt).getOrElse(sqlContext.sparkContext.defaultParallelism))
5651

57-
val writeRelation = new DynamoWriteRelation(writeData, parameters)(sqlContext)
58-
if (parameters.getOrElse("update", "false").toBoolean) {
52+
val writeRelation = new DynamoWriteRelation(data, parameters)(sqlContext)
53+
54+
if (parameters.getOrElse("update", "false").toBoolean)
5955
writeRelation.update()
60-
} else {
56+
else
6157
writeRelation.write()
6258

63-
}
6459
writeRelation
65-
6660
}
6761

6862
private def getGuavaVersion: String = try {

src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,18 @@ package com.audienceproject.spark.dynamodb.connector
2323
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
2424
import com.amazonaws.auth.profile.ProfileCredentialsProvider
2525
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
26-
import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBClientBuilder}
2726
import com.amazonaws.services.dynamodbv2.document.{DynamoDB, ItemCollection, ScanOutcome}
27+
import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBClientBuilder}
2828
import org.apache.spark.sql.sources.Filter
2929

3030
private[dynamodb] trait DynamoConnector {
3131

32-
def getDynamoDB(region:Option[String]=None): DynamoDB = {
32+
def getDynamoDB(region: Option[String] = None): DynamoDB = {
3333
val client: AmazonDynamoDB = getDynamoDBClient(region)
3434
new DynamoDB(client)
3535
}
3636

37-
def getDynamoDBClient(region:Option[String]=None) = {
37+
def getDynamoDBClient(region: Option[String] = None): AmazonDynamoDB = {
3838
val chosenRegion = region.getOrElse(sys.env.getOrElse("aws.dynamodb.region", "us-east-1"))
3939
Option(System.getProperty("aws.dynamodb.endpoint")).map(endpoint => {
4040
val credentials = Option(System.getProperty("aws.profile"))

src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoUpdatable.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import org.apache.spark.sql.types.StructType
2525

2626
trait DynamoUpdatable {
2727

28-
2928
def updateItems(schema: StructType)(items: Iterator[Row]): Unit
3029

3130
}

src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala

Lines changed: 73 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@
2020
*/
2121
package com.audienceproject.spark.dynamodb.connector
2222

23+
import java.util
24+
2325
import com.amazonaws.services.dynamodbv2.document._
24-
import com.amazonaws.services.dynamodbv2.document.spec.{BatchWriteItemSpec, ScanSpec}
25-
import com.amazonaws.services.dynamodbv2.model.{AttributeValue, ReturnConsumedCapacity, UpdateItemRequest, UpdateItemResult}
26+
import com.amazonaws.services.dynamodbv2.document.spec.{BatchWriteItemSpec, ScanSpec, UpdateItemSpec}
27+
import com.amazonaws.services.dynamodbv2.model.ReturnConsumedCapacity
2628
import com.amazonaws.services.dynamodbv2.xspec.ExpressionSpecBuilder
29+
import com.amazonaws.services.dynamodbv2.xspec.ExpressionSpecBuilder.{BOOL => newBOOL, L => newL, M => newM, N => newN, S => newS}
2730
import com.google.common.util.concurrent.RateLimiter
2831
import org.apache.spark.sql.Row
2932
import org.apache.spark.sql.sources.Filter
@@ -94,45 +97,6 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
9497
getDynamoDB(region).getTable(tableName).scan(scanSpec)
9598
}
9699

97-
override def updateItems(schema: StructType)(items: Iterator[Row]): Unit = {
98-
val columnNames = schema.map(_.name)
99-
val hashKeyIndex = columnNames.indexOf(keySchema.hashKeyName)
100-
val rangeKeyIndex = keySchema.rangeKeyName.map(columnNames.indexOf)
101-
val columnIndices = columnNames.zipWithIndex.filterNot({
102-
case (name, _) => keySchema match {
103-
case KeySchema(hashKey, None) => name == hashKey
104-
case KeySchema(hashKey, Some(rangeKey)) => name == hashKey || name == rangeKey
105-
}
106-
})
107-
108-
val rateLimiter = RateLimiter.create(writeLimit max 1)
109-
val client = getDynamoDBClient(region)
110-
111-
// For each item.
112-
items.foreach(row => {
113-
val key: Map[String, AttributeValue] = keySchema match {
114-
case KeySchema(hashKey, None) => Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType))
115-
case KeySchema(hashKey, Some(rangeKey)) =>
116-
Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType),
117-
rangeKey -> mapValueToAttributeValue(row(rangeKeyIndex.get), schema(rangeKey).dataType))
118-
119-
}
120-
val nonNullColumnIndices = columnIndices.filter(c => row(c._2) != null)
121-
val updateExpression = s"SET ${nonNullColumnIndices.map(c => s"${c._1}=:${c._1}").mkString(", ")}"
122-
val expressionAttributeValues = nonNullColumnIndices.map(c => s":${c._1}" -> mapValueToAttributeValue(row(c._2), schema(c._1).dataType)).toMap.asJava
123-
val updateItemReq = new UpdateItemRequest()
124-
.withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL)
125-
.withTableName(tableName)
126-
.withKey(key.asJava)
127-
.withUpdateExpression(updateExpression)
128-
.withExpressionAttributeValues(expressionAttributeValues)
129-
130-
val updateItemResult = client.updateItem(updateItemReq)
131-
132-
handleUpdateItemResult(rateLimiter)(updateItemResult)
133-
})
134-
}
135-
136100
override def putItems(schema: StructType, batchSize: Int)(items: Iterator[Row]): Unit = {
137101
val columnNames = schema.map(_.name)
138102
val hashKeyIndex = columnNames.indexOf(keySchema.hashKeyName)
@@ -174,46 +138,85 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
174138
))
175139

176140
val response = client.batchWriteItem(batchWriteItemSpec)
177-
178141
handleBatchWriteResponse(client, rateLimiter)(response)
179142
})
180143
}
181144

145+
override def updateItems(schema: StructType)(items: Iterator[Row]): Unit = {
146+
val columnNames = schema.map(_.name)
147+
val hashKeyIndex = columnNames.indexOf(keySchema.hashKeyName)
148+
val rangeKeyIndex = keySchema.rangeKeyName.map(columnNames.indexOf)
149+
val columnIndices = columnNames.zipWithIndex.filterNot({
150+
case (name, _) => keySchema match {
151+
case KeySchema(hashKey, None) => name == hashKey
152+
case KeySchema(hashKey, Some(rangeKey)) => name == hashKey || name == rangeKey
153+
}
154+
})
155+
156+
val rateLimiter = RateLimiter.create(writeLimit max 1)
157+
val client = getDynamoDB(region)
158+
159+
// For each item.
160+
items.foreach(row => {
161+
// Build update expression.
162+
val xspec = new ExpressionSpecBuilder()
163+
columnIndices.foreach({
164+
case (name, index) if !row.isNullAt(index) =>
165+
val updateAction = schema(name).dataType match {
166+
case StringType => newS(name).set(row.getString(index))
167+
case BooleanType => newBOOL(name).set(row.getBoolean(index))
168+
case IntegerType => newN(name).set(row.getInt(index))
169+
case LongType => newN(name).set(row.getLong(index))
170+
case ShortType => newN(name).set(row.getShort(index))
171+
case FloatType => newN(name).set(row.getFloat(index))
172+
case DoubleType => newN(name).set(row.getDouble(index))
173+
case ArrayType(innerType, _) => newL(name).set(row.getSeq[Any](index).map(e => mapValue(e, innerType)).asJava)
174+
case MapType(keyType, valueType, _) =>
175+
if (keyType != StringType) throw new IllegalArgumentException(
176+
s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
177+
newM(name).set(row.getMap[String, Any](index).mapValues(e => mapValue(e, valueType)).asJava)
178+
case StructType(fields) => newM(name).set(mapStruct(row.getStruct(index), fields))
179+
}
180+
xspec.addUpdate(updateAction)
181+
case _ =>
182+
})
183+
184+
val updateItemSpec = new UpdateItemSpec()
185+
.withExpressionSpec(xspec.buildForUpdate())
186+
.withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL)
187+
188+
// Map primary key.
189+
keySchema match {
190+
case KeySchema(hashKey, None) => updateItemSpec.withPrimaryKey(hashKey, row(hashKeyIndex))
191+
case KeySchema(hashKey, Some(rangeKey)) =>
192+
updateItemSpec.withPrimaryKey(hashKey, row(hashKeyIndex), rangeKey, row(rangeKeyIndex.get))
193+
}
194+
195+
if (updateItemSpec.getUpdateExpression.nonEmpty) {
196+
val response = client.getTable(tableName).updateItem(updateItemSpec)
197+
handleUpdateResponse(rateLimiter)(response)
198+
}
199+
})
200+
}
201+
182202
private def mapValue(element: Any, elementType: DataType): Any = {
183203
elementType match {
184204
case ArrayType(innerType, _) => element.asInstanceOf[Seq[_]].map(e => mapValue(e, innerType)).asJava
185205
case MapType(keyType, valueType, _) =>
186206
if (keyType != StringType) throw new IllegalArgumentException(
187207
s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
188-
element.asInstanceOf[Map[_, _]].mapValues(e => mapValue(e, valueType)).asJava
208+
element.asInstanceOf[Map[String, _]].mapValues(e => mapValue(e, valueType)).asJava
189209
case StructType(fields) =>
190210
val row = element.asInstanceOf[Row]
191-
(fields.indices map { i =>
192-
fields(i).name -> mapValue(row(i), fields(i).dataType)
193-
}).toMap.asJava
211+
mapStruct(row, fields)
194212
case _ => element
195213
}
196214
}
197215

198-
private def mapValueToAttributeValue(element: Any, elementType: DataType): AttributeValue = {
199-
elementType match {
200-
case ArrayType(innerType, _) => new AttributeValue().withL(element.asInstanceOf[Seq[_]].map(e => mapValueToAttributeValue(e, innerType)): _*)
201-
case MapType(keyType, valueType, _) =>
202-
if (keyType != StringType) throw new IllegalArgumentException(
203-
s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
204-
205-
new AttributeValue().withM(element.asInstanceOf[Map[String, _]].mapValues(e => mapValueToAttributeValue(e, valueType)).asJava)
206-
207-
case StructType(fields) =>
208-
val row = element.asInstanceOf[Row]
209-
new AttributeValue().withM((fields.indices map { i =>
210-
fields(i).name -> mapValueToAttributeValue(row(i), fields(i).dataType)
211-
}).toMap.asJava)
212-
case StringType => new AttributeValue().withS(element.asInstanceOf[String])
213-
case LongType | IntegerType | DoubleType | FloatType => new AttributeValue().withN(element.toString)
214-
case BooleanType => new AttributeValue().withBOOL(element.asInstanceOf[Boolean])
215-
}
216-
}
216+
private def mapStruct(row: Row, fields: Seq[StructField]): util.Map[String, Any] =
217+
(fields.indices map { i =>
218+
fields(i).name -> mapValue(row(i), fields(i).dataType)
219+
}).toMap.asJava
217220

218221
@tailrec
219222
private def handleBatchWriteResponse(client: DynamoDB, rateLimiter: RateLimiter)
@@ -231,12 +234,12 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
231234
}
232235
}
233236

234-
private def handleUpdateItemResult(rateLimiter: RateLimiter)
235-
(result: UpdateItemResult): Unit = {
237+
private def handleUpdateResponse(rateLimiter: RateLimiter)
238+
(response: UpdateItemOutcome): Unit = {
236239
// Rate limit on write capacity.
237-
if (result.getConsumedCapacity != null) {
238-
rateLimiter.acquire(result.getConsumedCapacity.getCapacityUnits.toInt)
239-
}
240+
Option(response.getUpdateItemResult.getConsumedCapacity)
241+
.map(_.getCapacityUnits.toInt)
242+
.foreach(rateLimiter.acquire)
240243
}
241244

242245
}

src/main/scala/com/audienceproject/spark/dynamodb/rdd/DynamoWriteRelation.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ private[dynamodb] class DynamoWriteRelation(data: DataFrame, parameters: Map[Str
3333

3434
private val tableName = parameters("tableName")
3535
private val batchSize = parameters.getOrElse("writeBatchSize", "25").toInt
36-
private val numPartitions = data.rdd.getNumPartitions
37-
private val connector = new TableConnector(tableName, numPartitions, parameters)
36+
private val connector = new TableConnector(tableName, sqlContext.sparkContext.defaultParallelism, parameters)
3837

3938
override val schema: StructType = data.schema
4039

@@ -48,5 +47,4 @@ private[dynamodb] class DynamoWriteRelation(data: DataFrame, parameters: Map[Str
4847
data.foreachPartition(connector.updateItems(schema) _)
4948
}
5049

51-
5250
}

src/test/scala/com/audienceproject/spark/dynamodb/WriteRelationTest.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class WriteRelationTest extends AbstractInMemoryTest {
5151
assert(validationDs.select("color").as[String].collect().forall(Seq("yellow", "orange", "red") contains _))
5252
assert(validationDs.select("weight").as[Double].collect().forall(Seq(0.1, 0.2, 0.2) contains _))
5353
}
54+
5455
test("Updating from a local Dataset with new and only some previous columns") {
5556
val tablename = "UpdateTest1"
5657
dynamoDB.createTable(new CreateTableRequest()
@@ -65,23 +66,22 @@ class WriteRelationTest extends AbstractInMemoryTest {
6566
("lemon", "yellow", 0.1),
6667
("orange", "orange", 0.2),
6768
("pomegranate", "red", 0.2)
68-
).toDF("name","color","weight")
69+
).toDF("name", "color", "weight")
6970
newItemsDs.write.dynamodb(tablename)
7071

7172
newItemsDs
72-
.withColumn("size",length($"color"))
73+
.withColumn("size", length($"color"))
7374
.drop("color")
74-
.withColumn("weight",$"weight"*2)
75-
.write.option("update","true").dynamodb(tablename)
75+
.withColumn("weight", $"weight" * 2)
76+
.write.option("update", "true").dynamodb(tablename)
7677

7778
val validationDs = spark.read.dynamodb(tablename)
7879
validationDs.show
7980
assert(validationDs.count() === 3)
8081
assert(validationDs.select("name").as[String].collect().forall(Seq("lemon", "orange", "pomegranate") contains _))
8182
assert(validationDs.select("color").as[String].collect().forall(Seq("yellow", "orange", "red") contains _))
8283
assert(validationDs.select("weight").as[Double].collect().forall(Seq(0.2, 0.4, 0.4) contains _))
83-
assert(validationDs.select("size").as[Long].collect().forall(Seq(6,3) contains _))
84-
84+
assert(validationDs.select("size").as[Long].collect().forall(Seq(6, 3) contains _))
8585
}
8686

8787
test("Updating from a local Dataset with null values") {
@@ -98,20 +98,20 @@ class WriteRelationTest extends AbstractInMemoryTest {
9898
("lemon", "yellow", 0.1),
9999
("orange", "orange", 0.2),
100100
("pomegranate", "red", 0.2)
101-
).toDF("name","color","weight")
101+
).toDF("name", "color", "weight")
102102
newItemsDs.write.dynamodb(tablename)
103103

104104
val alteredDs = newItemsDs
105-
.withColumn("weight",when($"weight" < 0.2,$"weight").otherwise(lit(null)))
105+
.withColumn("weight", when($"weight" < 0.2, $"weight").otherwise(lit(null)))
106106
alteredDs.show
107-
alteredDs.write.option("update","true").dynamodb(tablename)
107+
alteredDs.write.option("update", "true").dynamodb(tablename)
108108

109109
val validationDs = spark.read.dynamodb(tablename)
110110
validationDs.show
111111
assert(validationDs.count() === 3)
112112
assert(validationDs.select("name").as[String].collect().forall(Seq("lemon", "orange", "pomegranate") contains _))
113113
assert(validationDs.select("color").as[String].collect().forall(Seq("yellow", "orange", "red") contains _))
114114
assert(validationDs.select("weight").as[Double].collect().forall(Seq(0.2, 0.1) contains _))
115-
116115
}
116+
117117
}

0 commit comments

Comments
 (0)