This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit c706ed7

Merge branch 'master' into tmwong2003/fix-readme-call-to-read
2 parents f813833 + ceb9cb1 commit c706ed7

13 files changed: 294 additions, 34 deletions

README.md

Lines changed: 12 additions & 1 deletion
@@ -23,15 +23,21 @@ val dynamoDf = spark.read.dynamodb("SomeTableName") // <-- DataFrame of Row obje
 // Scan the table for the first 100 items (the order is arbitrary) and print them.
 dynamoDf.show(100)
 
+// write to some other table overwriting existing item with same keys
+dynamoDf.write.dynamodb("SomeOtherTable")
+
 // Case class representing the items in our table.
 import com.audienceproject.spark.dynamodb.attribute
 case class Vegetable (name: String, color: String, @attribute("weight_kg") weightKg: Double)
 
 // Load a Dataset[Vegetable]. Notice the @attribute annotation on the case class - we imagine the weight attribute is named with an underscore in DynamoDB.
 import org.apache.spark.sql.functions._
 import spark.implicits._
-val vegetableDs = spark.dynamodbAs[Vegetable]("VegeTable")
+val vegetableDs = spark.read.dynamodbAs[Vegetable]("VegeTable")
 val avgWeightByColor = vegetableDs.agg($"color", avg($"weightKg")) // The column is called 'weightKg' in the Dataset.
+
+
+
 ```
 
 ## Getting The Dependency
@@ -41,6 +47,10 @@ The library is available from Maven Central. Add the dependency in SBT as ```"co
 Spark is used in the library as a "provided" dependency, which means Spark has to be installed separately on the container where the application is running, such as is the case on AWS EMR.
 
 ## Parameters
+The following parameters can be set as options on the Spark reader and writer objects before loading/saving.
+- `region` sets the region in which the DynamoDB table resides. The default is environment specific.
+
+
 The following parameters can be set as options on the Spark reader object before loading.
 
 - `readPartitions` number of partitions to split the initial RDD when loading the data into Spark. Corresponds 1-to-1 with total number of segments in the DynamoDB parallel scan used to load the data. Defaults to `sparkContext.defaultParallelism`
@@ -53,6 +63,7 @@ The following parameters can be set as options on the Spark writer object before
 
 - `writePartitions` number of partitions to split the given DataFrame into when writing to DynamoDB. Set to `skip` to avoid repartitioning the DataFrame before writing. Defaults to `sparkContext.defaultParallelism`
 - `writeBatchSize` number of items to send per call to DynamoDB BatchWriteItem. Default 25.
+- `update` if true, writes use UpdateItem on keys rather than BatchWriteItem. Default false.
 
 ## Running Unit Tests
 The unit tests are dependent on the AWS DynamoDBLocal client, which in turn is dependent on [sqlite4java](https://bitbucket.org/almworks/sqlite4java/src/master/). I had some problems running this on OSX, so I had to put the library directly in the /lib folder, as graciously explained in [this Stack Overflow answer](https://stackoverflow.com/questions/34137043/amazon-dynamodb-local-unknown-error-exception-or-failure/35353377#35353377).
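The two options added in this README change plug into the normal Spark option mechanism. A minimal usage sketch, assuming the library's implicits import and placeholder table/region names that are not part of this diff:

```scala
import com.audienceproject.spark.dynamodb.implicits._

// Read from a table in an explicitly chosen region (placeholder values).
val someDf = spark.read
    .option("region", "eu-west-1")
    .dynamodb("SomeTableName")

// Write back using UpdateItem semantics instead of BatchWriteItem.
someDf.write
    .option("region", "eu-west-1")
    .option("update", "true")
    .dynamodb("SomeOtherTable")
```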

build.sbt

Lines changed: 8 additions & 1 deletion
@@ -2,7 +2,7 @@ organization := "com.audienceproject"
 
 name := "spark-dynamodb"
 
-version := "0.3.1"
+version := "0.3.2"
 
 description := "Plug-and-play implementation of an Apache Spark custom data source for AWS DynamoDB."
 
@@ -63,4 +63,11 @@ pomExtra := <url>https://github.com/audienceproject/spark-dynamodb</url>
 <organization>AudienceProject</organization>
 <organizationUrl>https://www.audienceproject.com</organizationUrl>
 </developer>
+<developer>
+<id>johsbk</id>
+<name>Johs Kristoffersen</name>
+<email>johs.kristoffersen@audienceproject.com</email>
+<organization>AudienceProject</organization>
+<organizationUrl>https://www.audienceproject.com</organizationUrl>
+</developer>
 </developers>
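Given the organization, module name, and bumped version in this build.sbt, a consuming project would pull the library in roughly as sketched below; the Spark artifact and version are illustrative placeholders, kept `provided` as the README describes:

```scala
// Hypothetical consumer build.sbt snippet; the Spark version is a placeholder.
libraryDependencies ++= Seq(
    "com.audienceproject" %% "spark-dynamodb" % "0.3.2",
    "org.apache.spark"    %% "spark-sql"      % "2.3.1" % "provided"
)
```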

src/main/scala/com/audienceproject/spark/dynamodb/DefaultSource.scala

Lines changed: 12 additions & 10 deletions
@@ -54,23 +54,25 @@ class DefaultSource extends RelationProvider
         if (parameters.get("writePartitions").contains("skip")) data
         else data.repartition(parameters.get("writePartitions").map(_.toInt).getOrElse(sqlContext.sparkContext.defaultParallelism))
 
-        val writeRelation = new DynamoWriteRelation(writeData, parameters)(sqlContext)
+        val writeRelation= new DynamoWriteRelation(writeData, parameters)(sqlContext)
+        if (parameters.getOrElse("update","false").toBoolean) {
+            writeRelation.update()
+        } else {
+            writeRelation.write()
 
-        writeRelation.write()
+        }
         writeRelation
+
     }
 
     private def getGuavaVersion: String = try {
         val file = new File(classOf[Charsets].getProtectionDomain.getCodeSource.getLocation.toURI)
-        try {
-            val jar = new JarFile(file)
-            try
-                jar.getManifest.getMainAttributes.getValue("Bundle-Version")
-            finally if (jar != null) jar.close()
-        }
+        val jar = new JarFile(file)
+        try
+            jar.getManifest.getMainAttributes.getValue("Bundle-Version")
+        finally if (jar != null) jar.close()
     } catch {
-        case ex: Exception =>
-            throw new RuntimeException("Unable to get the version of Guava", ex)
+        case ex: Exception => throw new RuntimeException("Unable to get the version of Guava", ex)
     }
 
 }

src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala

Lines changed: 12 additions & 8 deletions
@@ -23,29 +23,33 @@ package com.audienceproject.spark.dynamodb.connector
 import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
 import com.amazonaws.auth.profile.ProfileCredentialsProvider
 import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
-import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClientBuilder
+import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBClientBuilder}
 import com.amazonaws.services.dynamodbv2.document.{DynamoDB, ItemCollection, ScanOutcome}
 import org.apache.spark.sql.sources.Filter
 
 private[dynamodb] trait DynamoConnector {
 
-    def getClient: DynamoDB = {
-        val client = Option(System.getProperty("aws.dynamodb.endpoint")).map(endpoint => {
-            val region = sys.env.getOrElse("aws.dynamodb.region", "us-east-1")
+    def getDynamoDB(region: Option[String] = None): DynamoDB = {
+        val client: AmazonDynamoDB = getDynamoDBClient(region)
+        new DynamoDB(client)
+    }
+
+    def getDynamoDBClient(region: Option[String] = None) = {
+        val chosenRegion = region.getOrElse(sys.env.getOrElse("aws.dynamodb.region", "us-east-1"))
+        Option(System.getProperty("aws.dynamodb.endpoint")).map(endpoint => {
             val credentials = Option(System.getProperty("aws.profile"))
                 .map(new ProfileCredentialsProvider(_))
                 .getOrElse(new DefaultAWSCredentialsProviderChain)
             AmazonDynamoDBClientBuilder.standard()
                 .withCredentials(credentials)
-                .withEndpointConfiguration(new EndpointConfiguration(endpoint, region))
+                .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
                 .build()
-        }).getOrElse(AmazonDynamoDBClientBuilder.defaultClient())
-        new DynamoDB(client)
+        }).getOrElse(AmazonDynamoDBClientBuilder.standard().withRegion(chosenRegion).build())
     }
 
     val keySchema: KeySchema
 
-    val readLimit: Int
+    val readLimit: Double
 
     val itemLimit: Int
 
src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoUpdatable.scala

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+/**
+  * Licensed to the Apache Software Foundation (ASF) under one
+  * or more contributor license agreements. See the NOTICE file
+  * distributed with this work for additional information
+  * regarding copyright ownership. The ASF licenses this file
+  * to you under the Apache License, Version 2.0 (the
+  * "License"); you may not use this file except in compliance
+  * with the License. You may obtain a copy of the License at
+  *
+  * http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing,
+  * software distributed under the License is distributed on an
+  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  * KIND, either express or implied. See the License for the
+  * specific language governing permissions and limitations
+  * under the License.
+  *
+  * Copyright © 2018 AudienceProject. All rights reserved.
+  */
+package com.audienceproject.spark.dynamodb.connector
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.StructType
+
+trait DynamoUpdatable {
+
+
+    def updateItems(schema: StructType)(items: Iterator[Row]): Unit
+
+}

src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoWritable.scala

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructType
 
 trait DynamoWritable {
 
-    val writeLimit: Int
+    val writeLimit: Double
 
     def putItems(schema: StructType, batchSize: Int)(items: Iterator[Row]): Unit
 
src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala

Lines changed: 77 additions & 8 deletions
@@ -22,7 +22,7 @@ package com.audienceproject.spark.dynamodb.connector
 
 import com.amazonaws.services.dynamodbv2.document._
 import com.amazonaws.services.dynamodbv2.document.spec.{BatchWriteItemSpec, ScanSpec}
-import com.amazonaws.services.dynamodbv2.model.ReturnConsumedCapacity
+import com.amazonaws.services.dynamodbv2.model.{AttributeValue, ReturnConsumedCapacity, UpdateItemRequest, UpdateItemResult}
 import com.amazonaws.services.dynamodbv2.xspec.ExpressionSpecBuilder
 import com.google.common.util.concurrent.RateLimiter
 import org.apache.spark.sql.Row
@@ -33,13 +33,14 @@ import scala.annotation.tailrec
 import scala.collection.JavaConverters._
 
 private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, parameters: Map[String, String])
-    extends DynamoConnector with DynamoWritable with Serializable {
+    extends DynamoConnector with DynamoWritable with DynamoUpdatable with Serializable {
 
     private val consistentRead = parameters.getOrElse("stronglyConsistentReads", "false").toBoolean
     private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean
+    private val region = parameters.get("region")
 
     override val (keySchema, readLimit, writeLimit, itemLimit, totalSizeInBytes) = {
-        val table = getClient.getTable(tableName)
+        val table = getDynamoDB(region).getTable(tableName)
         val desc = table.describe()
 
         // Key schema.
@@ -56,10 +57,10 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
         val readCapacity = desc.getProvisionedThroughput.getReadCapacityUnits * targetCapacity
         val writeCapacity = desc.getProvisionedThroughput.getWriteCapacityUnits * targetCapacity
 
-        val readLimit = (readCapacity / totalSegments).toInt max 1
-        val itemLimit = (bytesPerRCU / avgItemSize * readLimit).toInt * readFactor
+        val readLimit = readCapacity / totalSegments
+        val itemLimit = ((bytesPerRCU / avgItemSize * readLimit).toInt * readFactor) max 1
 
-        val writeLimit = (writeCapacity / totalSegments).toInt
+        val writeLimit = writeCapacity / totalSegments
 
         (keySchema, readLimit, writeLimit, itemLimit, tableSize.toLong)
     }
@@ -82,7 +83,48 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
             scanSpec.withExpressionSpec(xspec.buildForScan())
         }
 
-        getClient.getTable(tableName).scan(scanSpec)
+        getDynamoDB(region).getTable(tableName).scan(scanSpec)
+    }
+
+    override def updateItems(schema: StructType)(items: Iterator[Row]): Unit = {
+        val columnNames = schema.map(_.name)
+        val hashKeyIndex = columnNames.indexOf(keySchema.hashKeyName)
+        val rangeKeyIndex = keySchema.rangeKeyName.map(columnNames.indexOf)
+        val columnIndices = columnNames.zipWithIndex.filterNot({
+            case (name, _) => keySchema match {
+                case KeySchema(hashKey, None) => name == hashKey
+                case KeySchema(hashKey, Some(rangeKey)) => name == hashKey || name == rangeKey
+            }
+        })
+
+        val rateLimiter = RateLimiter.create(writeLimit max 1)
+        val client = getDynamoDBClient(region)
+
+
+
+        // For each item.
+        items.foreach(row => {
+            val key: Map[String, AttributeValue] = keySchema match {
+                case KeySchema(hashKey, None) => Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType))
+                case KeySchema(hashKey, Some(rangeKey)) =>
+                    Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType),
+                        rangeKey -> mapValueToAttributeValue(row(rangeKeyIndex.get), schema(rangeKey).dataType))
+
+            }
+            val nonNullColumnIndices = columnIndices.filter(c => row(c._2) != null)
+            val updateExpression = s"SET ${nonNullColumnIndices.map(c => s"${c._1}=:${c._1}").mkString(", ")}"
+            val expressionAttributeValues = nonNullColumnIndices.map(c => s":${c._1}" -> mapValueToAttributeValue(row(c._2), schema(c._1).dataType)).toMap.asJava
+            val updateItemReq = new UpdateItemRequest()
+                .withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL)
+                .withTableName(tableName)
+                .withKey(key.asJava)
+                .withUpdateExpression(updateExpression)
+                .withExpressionAttributeValues(expressionAttributeValues)
+
+            val updateItemResult = client.updateItem(updateItemReq)
+
+            handleUpdateItemResult(rateLimiter)(updateItemResult)
+        })
     }
 
     override def putItems(schema: StructType, batchSize: Int)(items: Iterator[Row]): Unit = {
@@ -97,7 +139,7 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
         })
 
         val rateLimiter = RateLimiter.create(writeLimit max 1)
-        val client = getClient
+        val client = getDynamoDB(region)
 
         // For each batch.
         items.grouped(batchSize).foreach(itemBatch => {
@@ -147,6 +189,26 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
         }
     }
 
+    private def mapValueToAttributeValue(element: Any, elementType: DataType): AttributeValue = {
+        elementType match {
+            case ArrayType(innerType, _) => new AttributeValue().withL(element.asInstanceOf[Seq[_]].map(e => mapValueToAttributeValue(e, innerType)): _*)
+            case MapType(keyType, valueType, _) =>
+                if (keyType != StringType) throw new IllegalArgumentException(
+                    s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
+
+                new AttributeValue().withM(element.asInstanceOf[Map[String, _]].mapValues(e => mapValueToAttributeValue(e, valueType)).asJava)
+
+            case StructType(fields) =>
+                val row = element.asInstanceOf[Row]
+                new AttributeValue().withM((fields.indices map { i =>
+                    fields(i).name -> mapValueToAttributeValue(row(i), fields(i).dataType)
+                }).toMap.asJava)
+            case StringType => new AttributeValue().withS(element.asInstanceOf[String])
+            case LongType | IntegerType | DoubleType | FloatType => new AttributeValue().withN(element.toString)
+            case BooleanType => new AttributeValue().withBOOL(element.asInstanceOf[Boolean])
+        }
+    }
+
     @tailrec
     private def handleBatchWriteResponse(client: DynamoDB, rateLimiter: RateLimiter)
                                         (response: BatchWriteItemOutcome): Unit = {
@@ -162,5 +224,12 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
             handleBatchWriteResponse(client, rateLimiter)(newResponse)
         }
     }
+    private def handleUpdateItemResult(rateLimiter: RateLimiter)
+                                      (result: UpdateItemResult): Unit = {
+        // Rate limit on write capacity.
+        if (result.getConsumedCapacity != null) {
+            rateLimiter.acquire(result.getConsumedCapacity.getCapacityUnits.toInt)
+        }
+    }
 
 }
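To make the request shape built by `updateItems` concrete, here is a hedged sketch of the call produced for a single row, assuming a hypothetical table keyed on `name` with non-null columns `color` and `weight_kg`; the names and values are illustrative and not part of this commit:

```scala
import com.amazonaws.services.dynamodbv2.model.{AttributeValue, ReturnConsumedCapacity, UpdateItemRequest}
import scala.collection.JavaConverters._

// Illustration only: key and column names are assumptions, but the expression
// format mirrors the interpolation in updateItems above (null columns are skipped).
val exampleRequest = new UpdateItemRequest()
    .withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL)
    .withTableName("VegeTable")
    .withKey(Map("name" -> new AttributeValue().withS("carrot")).asJava)
    .withUpdateExpression("SET color=:color, weight_kg=:weight_kg")
    .withExpressionAttributeValues(Map(
        ":color" -> new AttributeValue().withS("orange"),
        ":weight_kg" -> new AttributeValue().withN("0.3")
    ).asJava)
```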

src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala

Lines changed: 5 additions & 4 deletions
@@ -33,9 +33,10 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
 
     private val consistentRead = parameters.getOrElse("stronglyConsistentReads", "false").toBoolean
     private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean
+    private val region = parameters.get("region")
 
     override val (keySchema, readLimit, itemLimit, totalSizeInBytes) = {
-        val table = getClient.getTable(tableName)
+        val table = getDynamoDB(region).getTable(tableName)
         val indexDesc = table.describe().getGlobalSecondaryIndexes.asScala.find(_.getIndexName == indexName).get
 
         // Key schema.
@@ -51,8 +52,8 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
         val avgItemSize = tableSize.toDouble / indexDesc.getItemCount
         val readCapacity = indexDesc.getProvisionedThroughput.getReadCapacityUnits * targetCapacity
 
-        val rateLimit = (readCapacity / totalSegments).toInt max 1
-        val itemLimit = (bytesPerRCU / avgItemSize * rateLimit).toInt * readFactor
+        val rateLimit = readCapacity / totalSegments
+        val itemLimit = ((bytesPerRCU / avgItemSize * rateLimit).toInt * readFactor) max 1
 
         (keySchema, rateLimit, itemLimit, tableSize.toLong)
     }
@@ -75,7 +76,7 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
             scanSpec.withExpressionSpec(xspec.buildForScan())
         }
 
-        getClient.getTable(tableName).getIndex(indexName).scan(scanSpec)
+        getDynamoDB(region).getTable(tableName).getIndex(indexName).scan(scanSpec)
     }
 
 }

src/main/scala/com/audienceproject/spark/dynamodb/rdd/DynamoWriteRelation.scala

Lines changed: 5 additions & 0 deletions
@@ -44,4 +44,9 @@ private[dynamodb] class DynamoWriteRelation(data: DataFrame, parameters: Map[Str
         data.foreachPartition(connector.putItems(schema, batchSize) _)
     }
 
+    def update(): Unit = {
+        data.foreachPartition(connector.updateItems(schema) _)
+    }
+
+
 }

src/main/scala/com/audienceproject/spark/dynamodb/rdd/ScanPartition.scala

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ private[dynamodb] class ScanPartition(schema: StructType,
 
         if (connector.isEmpty) return Iterator.empty
 
-        val rateLimiter = RateLimiter.create(connector.readLimit max 1)
+        val rateLimiter = RateLimiter.create(connector.readLimit)
 
         val scanResult = connector.scan(index, requiredColumns, filters)
 
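A note on the `RateLimiter.create(connector.readLimit)` change above: Guava's RateLimiter accepts fractional permits per second, so typing `readLimit` as Double lets a segment's share of read capacity drop below 1 RCU/s instead of being rounded up. A minimal sketch, where the 0.5 figure is an illustrative assumption:

```scala
import com.google.common.util.concurrent.RateLimiter

// e.g. 1 provisioned RCU split across 2 scan segments gives each segment 0.5 permits/second.
val limiter = RateLimiter.create(0.5)
limiter.acquire(1) // subsequent acquires block for roughly two seconds each
```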