This repository was archived by the owner on Aug 31, 2021. It is now read-only.

use async client for update, because we need speed! #11

Open · wants to merge 1 commit into base: master
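This PR switches the connector's update path from the blocking DynamoDB client to the async client, submits update requests in batches, and rate-limits on the write capacity each batch reports. For orientation, a minimal usage sketch of the update path from the Spark side; it follows the pattern in the test suite below, but the table name, column, and SparkSession setup are illustrative rather than part of this diff:

import org.apache.spark.sql.SparkSession
import com.audienceproject.spark.dynamodb.implicits._

// Hypothetical example: double the "weight" attribute of items already stored in the "fruit" table.
val spark = SparkSession.builder().appName("update-example").getOrCreate()
import spark.implicits._

spark.read.dynamodb("fruit")
  .withColumn("weight", $"weight" * 2)
  .write
  .option("update", "true") // routes the write through updateItems, which this PR makes asynchronous and batched
  .dynamodb("fruit")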
6 changes: 4 additions & 2 deletions build.sbt
@@ -11,6 +11,8 @@ scalaVersion := "2.11.12"
resolvers += "DynamoDBLocal" at "https://s3-us-west-2.amazonaws.com/dynamodb-local/release"

libraryDependencies += "com.amazonaws" % "aws-java-sdk-dynamodb" % "1.11.325"
// https://mvnrepository.com/artifact/com.amazonaws/amazon-dax-client
libraryDependencies += "com.amazonaws" % "amazon-dax-client" % "1.0.200704.0"
libraryDependencies += "com.amazonaws" % "DynamoDBLocal" % "[1.11,2.0)" % "test"

libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.3.1" % "provided"
@@ -28,9 +28,9 @@ libraryDependencies ++= {
"org.apache.logging.log4j" % "log4j-slf4j-impl" % log4j2Version % "test"
)
}

test in assembly := {}
fork in Test := true
javaOptions in Test ++= Seq("-Djava.library.path=./lib/sqlite4java", "-Daws.dynamodb.endpoint=http://localhost:8000")
javaOptions in Test ++= Seq("-Djava.library.path=./lib/sqlite4java", "-Daws.dynamodb.endpoint=http://localhost:8000", "-Daws.accessKeyId=asdf", "-Daws.secretKey=asdf")

/**
* Maven specific settings for publishing to Maven central.
1 change: 1 addition & 0 deletions project/assembly.sbt
@@ -0,0 +1 @@
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.7")
@@ -23,30 +23,55 @@ package com.audienceproject.spark.dynamodb.connector
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
import com.amazonaws.auth.profile.ProfileCredentialsProvider
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBClientBuilder}
import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBAsyncClientBuilder, AmazonDynamoDBClientBuilder}
import com.amazonaws.services.dynamodbv2.document.{DynamoDB, ItemCollection, ScanOutcome}
import org.apache.spark.sql.sources.Filter
import com.amazon.dax.client.dynamodbv2.AmazonDaxClientBuilder
import com.amazonaws.regions.DefaultAwsRegionProviderChain

private[dynamodb] trait DynamoConnector {

def getDynamoDB(region:Option[String]=None): DynamoDB = {
val client: AmazonDynamoDB = getDynamoDBClient(region)
def getDynamoDB(parameters: Map[String, String] = Map.empty): DynamoDB = {
val client: AmazonDynamoDB = getDynamoDBClient(parameters)
new DynamoDB(client)
}

def getDynamoDBClient(region:Option[String]=None) = {
val chosenRegion = region.getOrElse(sys.env.getOrElse("aws.dynamodb.region", "us-east-1"))

def getDynamoDBClient(parameters: Map[String, String] = Map.empty) = {
val builder = AmazonDynamoDBClientBuilder.standard()
val credentials = Option(System.getProperty("aws.profile"))
.map(new ProfileCredentialsProvider(_))
.getOrElse(DefaultAWSCredentialsProviderChain.getInstance())

val region: String = parameters.get("region").orElse(sys.env.get("aws.dynamodb.region")).getOrElse(new DefaultAwsRegionProviderChain().getRegion)
builder.withCredentials(credentials)

Option(System.getProperty("aws.dynamodb.endpoint")).map(endpoint => {
val credentials = Option(System.getProperty("aws.profile"))
.map(new ProfileCredentialsProvider(_))
.getOrElse(new DefaultAWSCredentialsProviderChain)
AmazonDynamoDBClientBuilder.standard()
.withCredentials(credentials)
.withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
.build()
}).getOrElse(AmazonDynamoDBClientBuilder.standard().withRegion(chosenRegion).build())
builder
.withEndpointConfiguration(new EndpointConfiguration(endpoint, region))
}).getOrElse(
builder.withRegion(region)).build()

}

def getDynamoDBAsyncClient(parameters: Map[String, String] = Map.empty) = {
val builder = AmazonDynamoDBAsyncClientBuilder.standard()
val credentials = Option(System.getProperty("aws.profile"))
.map(new ProfileCredentialsProvider(_))
.getOrElse(DefaultAWSCredentialsProviderChain.getInstance())

val region: String = parameters.get("region").orElse(sys.env.get("aws.dynamodb.region")).getOrElse(new DefaultAwsRegionProviderChain().getRegion)
builder.withCredentials(credentials)

Option(System.getProperty("aws.dynamodb.endpoint")).map(endpoint => {
builder
.withEndpointConfiguration(new EndpointConfiguration(endpoint, region))
}).getOrElse(
builder.withRegion(region)).build()

}


val keySchema: KeySchema

val readLimit: Double
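For reference, a standalone sketch of what the new getDynamoDBAsyncClient builds when the aws.dynamodb.endpoint system property is set; the endpoint and region values mirror the test configuration in build.sbt, and this is an illustration rather than code from the PR:

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
import com.amazonaws.services.dynamodbv2.{AmazonDynamoDBAsync, AmazonDynamoDBAsyncClientBuilder}

// Hypothetical sketch: async client pointed at DynamoDB Local, as the tests do.
val asyncClient: AmazonDynamoDBAsync = AmazonDynamoDBAsyncClientBuilder.standard()
  .withCredentials(DefaultAWSCredentialsProviderChain.getInstance())
  .withEndpointConfiguration(new EndpointConfiguration("http://localhost:8000", "us-east-1"))
  .build()

Without the endpoint property, the builder falls back to withRegion, where the region comes from the "region" parameter, the aws.dynamodb.region environment variable, or the SDK's DefaultAwsRegionProviderChain, in that order.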
@@ -26,6 +26,6 @@ import org.apache.spark.sql.types.StructType
trait DynamoUpdatable {


def updateItems(schema: StructType)(items: Iterator[Row]): Unit
def updateItems(schema: StructType, batchSize: Int)(items: Iterator[Row]): Unit

}
@@ -37,10 +37,9 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa

private val consistentRead = parameters.getOrElse("stronglyConsistentReads", "false").toBoolean
private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean
private val region = parameters.get("region")

override val (keySchema, readLimit, writeLimit, itemLimit, totalSizeInBytes) = {
val table = getDynamoDB(region).getTable(tableName)
val table = getDynamoDB(parameters).getTable(tableName)
val desc = table.describe()

// Key schema.
@@ -83,10 +82,10 @@
scanSpec.withExpressionSpec(xspec.buildForScan())
}

getDynamoDB(region).getTable(tableName).scan(scanSpec)
getDynamoDB(parameters).getTable(tableName).scan(scanSpec)
}

override def updateItems(schema: StructType)(items: Iterator[Row]): Unit = {
override def updateItems(schema: StructType, batchSize: Int)(items: Iterator[Row]): Unit = {
val columnNames = schema.map(_.name)
val hashKeyIndex = columnNames.indexOf(keySchema.hashKeyName)
val rangeKeyIndex = keySchema.rangeKeyName.map(columnNames.indexOf)
@@ -98,32 +97,38 @@
})

val rateLimiter = RateLimiter.create(writeLimit max 1)
val client = getDynamoDBClient(region)
val client = getDynamoDBAsyncClient(parameters)



// For each item.
items.foreach(row => {
val key:Map[String,AttributeValue] = keySchema match {
case KeySchema(hashKey, None) => Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType))
case KeySchema(hashKey, Some(rangeKey)) =>
Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType),
rangeKey-> mapValueToAttributeValue(row(rangeKeyIndex.get), schema(rangeKey).dataType))

}
val nonNullColumnIndices =columnIndices.filter(c => row(c._2)!=null)
val updateExpression = s"SET ${nonNullColumnIndices.map(c => s"${c._1}=:${c._1}").mkString(", ")}"
val expressionAttributeValues = nonNullColumnIndices.map(c => s":${c._1}" -> mapValueToAttributeValue(row(c._2), schema(c._1).dataType)).toMap.asJava
val updateItemReq = new UpdateItemRequest()
.withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL)
.withTableName(tableName)
.withKey(key.asJava)
.withUpdateExpression(updateExpression)
.withExpressionAttributeValues(expressionAttributeValues)

val updateItemResult = client.updateItem(updateItemReq)

handleUpdateItemResult(rateLimiter)(updateItemResult)
items.grouped(batchSize).foreach(itemBatch => {
val results = itemBatch.map(row => {
val key: Map[String, AttributeValue] = keySchema match {
case KeySchema(hashKey, None) => Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType))
case KeySchema(hashKey, Some(rangeKey)) =>
Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType),
rangeKey -> mapValueToAttributeValue(row(rangeKeyIndex.get), schema(rangeKey).dataType))

}
val nonNullColumnIndices = columnIndices.filter(c => row(c._2) != null)
val updateExpression = s"SET ${nonNullColumnIndices.map(c => s"#${c._2}=:${c._2}").mkString(", ")}"
val expressionAttributeValues = nonNullColumnIndices.map(c => s":${c._2}" -> mapValueToAttributeValue(row(c._2), schema(c._1).dataType)).toMap.asJava
val updateItemReq = new UpdateItemRequest()
.withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL)
.withTableName(tableName)
.withKey(key.asJava)
.withUpdateExpression(updateExpression)
.withExpressionAttributeNames(nonNullColumnIndices.map(c => s"#${c._2}" -> c._1).toMap.asJava)
.withExpressionAttributeValues(expressionAttributeValues)

client.updateItemAsync(updateItemReq)
})
val unitsSpent = results.map(f => (try { Option(f.get()) } catch { case _: Exception => Option.empty })
.flatMap(c => Option(c.getConsumedCapacity))
.map(_.getCapacityUnits)
.getOrElse(Double.box(1.0))).reduce((a, b) => a + b)
rateLimiter.acquire(unitsSpent.toInt)
})
}

@@ -139,7 +144,7 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
})

val rateLimiter = RateLimiter.create(writeLimit max 1)
val client = getDynamoDB(region)
val client = getDynamoDB(parameters)

// For each batch.
items.grouped(batchSize).foreach(itemBatch => {
@@ -224,12 +229,5 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
handleBatchWriteResponse(client, rateLimiter)(newResponse)
}
}
private def handleUpdateItemResult(rateLimiter: RateLimiter)
(result: UpdateItemResult): Unit = {
// Rate limit on write capacity.
if (result.getConsumedCapacity != null) {
rateLimiter.acquire(result.getConsumedCapacity.getCapacityUnits.toInt)
}
}

}
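The core of this PR is the per-batch pattern in updateItems above: every UpdateItemRequest in a batch is fired through the async client, then the returned futures are awaited, the consumed write capacity is summed, and the rate limiter is charged before the next batch starts. A self-contained sketch of that pattern under stated assumptions (an AmazonDynamoDBAsync client, Guava's RateLimiter as used elsewhere in the connector, and a prepared batch of requests; the helper name updateBatch is illustrative):

import java.util.concurrent.Future
import com.amazonaws.services.dynamodbv2.AmazonDynamoDBAsync
import com.amazonaws.services.dynamodbv2.model.{UpdateItemRequest, UpdateItemResult}
import com.google.common.util.concurrent.RateLimiter

def updateBatch(client: AmazonDynamoDBAsync,
                rateLimiter: RateLimiter,
                requests: Seq[UpdateItemRequest]): Unit = {
  // Fire every request without blocking; the SDK executes them on its own thread pool.
  val futures: Seq[Future[UpdateItemResult]] = requests.map(r => client.updateItemAsync(r))
  // Await the whole batch and sum the reported write capacity, falling back to
  // one unit per item when a result or its consumed capacity is unavailable.
  val unitsSpent = futures.map { f =>
    (try Option(f.get()) catch { case _: Exception => None })
      .flatMap(r => Option(r.getConsumedCapacity))
      .map(_.getCapacityUnits.doubleValue())
      .getOrElse(1.0)
  }.sum
  // Throttle before the next batch so sustained throughput stays under the table's write limit.
  rateLimiter.acquire(math.max(1, unitsSpent.toInt))
}

Compared with the previous one-request-at-a-time loop, the waits for DynamoDB round trips now overlap across the whole batch, which is where the speed-up in the PR title comes from.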
@@ -36,7 +36,7 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
private val region = parameters.get("region")

override val (keySchema, readLimit, itemLimit, totalSizeInBytes) = {
val table = getDynamoDB(region).getTable(tableName)
val table = getDynamoDB(parameters).getTable(tableName)
val indexDesc = table.describe().getGlobalSecondaryIndexes.asScala.find(_.getIndexName == indexName).get

// Key schema.
@@ -76,7 +76,7 @@
scanSpec.withExpressionSpec(xspec.buildForScan())
}

getDynamoDB(region).getTable(tableName).getIndex(indexName).scan(scanSpec)
getDynamoDB(parameters).getTable(tableName).getIndex(indexName).scan(scanSpec)
}

}
@@ -45,7 +45,7 @@ private[dynamodb] class DynamoWriteRelation(data: DataFrame, parameters: Map[Str
}

def update(): Unit = {
data.foreachPartition(connector.updateItems(schema) _)
data.foreachPartition(connector.updateItems(schema, batchSize) _)
}


@@ -69,7 +69,7 @@ class WriteRelationTest extends AbstractInMemoryTest {
newItemsDs.write.dynamodb(tablename)

newItemsDs
.withColumn("size",length($"color"))
.withColumn("si:ze",length($"color"))
.drop("color")
.withColumn("weight",$"weight"*2)
.write.option("update","true").dynamodb(tablename)
@@ -80,7 +80,7 @@ class WriteRelationTest extends AbstractInMemoryTest
assert(validationDs.select("name").as[String].collect().forall(Seq("lemon", "orange", "pomegranate") contains _))
assert(validationDs.select("color").as[String].collect().forall(Seq("yellow", "orange", "red") contains _))
assert(validationDs.select("weight").as[Double].collect().forall(Seq(0.2, 0.4, 0.4) contains _))
assert(validationDs.select("size").as[Long].collect().forall(Seq(6,3) contains _))
assert(validationDs.select("si:ze").as[Long].collect().forall(Seq(6,3) contains _))

}

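The test now writes the derived column as si:ze instead of size, which exercises the other change to updateItems: attribute names are passed through withExpressionAttributeNames as # placeholders, so names containing characters like ':' (or DynamoDB reserved words such as size) never appear literally in the update expression. A minimal illustration of the resulting request shape; the table name and key value are hypothetical, not taken from this diff:

import scala.collection.JavaConverters._
import com.amazonaws.services.dynamodbv2.model.{AttributeValue, ReturnConsumedCapacity, UpdateItemRequest}

// Hypothetical request: set the "si:ze" attribute of the item keyed by name = "lemon".
val request = new UpdateItemRequest()
  .withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL)
  .withTableName("TestFruit") // hypothetical table name
  .withKey(Map("name" -> new AttributeValue("lemon")).asJava)
  .withUpdateExpression("SET #1=:1")
  .withExpressionAttributeNames(Map("#1" -> "si:ze").asJava) // aliases the attribute name that is illegal in an expression
  .withExpressionAttributeValues(Map(":1" -> new AttributeValue().withN("5")).asJava)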