diff --git a/build.sbt b/build.sbt
index 4d29a5e..38eef94 100644
--- a/build.sbt
+++ b/build.sbt
@@ -11,6 +11,8 @@ scalaVersion := "2.11.12"
 resolvers += "DynamoDBLocal" at "https://s3-us-west-2.amazonaws.com/dynamodb-local/release"
 
 libraryDependencies += "com.amazonaws" % "aws-java-sdk-dynamodb" % "1.11.325"
+// https://mvnrepository.com/artifact/com.amazonaws/amazon-dax-client
+libraryDependencies += "com.amazonaws" % "amazon-dax-client" % "1.0.200704.0"
 libraryDependencies += "com.amazonaws" % "DynamoDBLocal" % "[1.11,2.0)" % "test"
 
 libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.3.1" % "provided"
@@ -28,9 +30,9 @@ libraryDependencies ++= {
         "org.apache.logging.log4j" % "log4j-slf4j-impl" % log4j2Version % "test"
     )
 }
-
+test in assembly := {}
 fork in Test := true
-javaOptions in Test ++= Seq("-Djava.library.path=./lib/sqlite4java", "-Daws.dynamodb.endpoint=http://localhost:8000")
+javaOptions in Test ++= Seq("-Djava.library.path=./lib/sqlite4java", "-Daws.dynamodb.endpoint=http://localhost:8000", "-Daws.accessKeyId=asdf", "-Daws.secretKey=asdf")
 
 /**
  * Maven specific settings for publishing to Maven central.
diff --git a/project/assembly.sbt b/project/assembly.sbt
new file mode 100644
index 0000000..d95475f
--- /dev/null
+++ b/project/assembly.sbt
@@ -0,0 +1 @@
+addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.7")
diff --git a/src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala b/src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala
index dc382a2..5ad1539 100644
--- a/src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala
+++ b/src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala
@@ -23,30 +23,55 @@ package com.audienceproject.spark.dynamodb.connector
 import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
 import com.amazonaws.auth.profile.ProfileCredentialsProvider
 import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
-import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBClientBuilder}
+import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBAsyncClientBuilder, AmazonDynamoDBClientBuilder}
 import com.amazonaws.services.dynamodbv2.document.{DynamoDB, ItemCollection, ScanOutcome}
 import org.apache.spark.sql.sources.Filter
+import com.amazon.dax.client.dynamodbv2.AmazonDaxClientBuilder
+import com.amazonaws.regions.DefaultAwsRegionProviderChain
 
 private[dynamodb] trait DynamoConnector {
 
-    def getDynamoDB(region:Option[String]=None): DynamoDB = {
-        val client: AmazonDynamoDB = getDynamoDBClient(region)
+    def getDynamoDB(parameters: Map[String, String] = Map.empty): DynamoDB = {
+        val client: AmazonDynamoDB = getDynamoDBClient(parameters)
         new DynamoDB(client)
     }
 
-    def getDynamoDBClient(region:Option[String]=None) = {
-        val chosenRegion = region.getOrElse(sys.env.getOrElse("aws.dynamodb.region", "us-east-1"))
+
+    def getDynamoDBClient(parameters: Map[String, String] = Map.empty) = {
+        val builder = AmazonDynamoDBClientBuilder.standard()
+        val credentials = Option(System.getProperty("aws.profile"))
+            .map(new ProfileCredentialsProvider(_))
+            .getOrElse(DefaultAWSCredentialsProviderChain.getInstance())
+
+        val region: String = parameters.get("region").orElse(sys.env.get("aws.dynamodb.region")).getOrElse(new DefaultAwsRegionProviderChain().getRegion)
+        builder.withCredentials(credentials)
+
         Option(System.getProperty("aws.dynamodb.endpoint")).map(endpoint => {
-            val credentials = Option(System.getProperty("aws.profile"))
-                .map(new ProfileCredentialsProvider(_))
-                .getOrElse(new DefaultAWSCredentialsProviderChain)
-            AmazonDynamoDBClientBuilder.standard()
-                .withCredentials(credentials)
-                .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
-                .build()
-        }).getOrElse(AmazonDynamoDBClientBuilder.standard().withRegion(chosenRegion).build())
+            builder
+                .withEndpointConfiguration(new EndpointConfiguration(endpoint, region))
+        }).getOrElse(
+            builder.withRegion(region)).build()
     }
 
+    def getDynamoDBAsyncClient(parameters: Map[String, String] = Map.empty) = {
+        val builder = AmazonDynamoDBAsyncClientBuilder.standard()
+        val credentials = Option(System.getProperty("aws.profile"))
+            .map(new ProfileCredentialsProvider(_))
+            .getOrElse(DefaultAWSCredentialsProviderChain.getInstance())
+
+        val region: String = parameters.get("region").orElse(sys.env.get("aws.dynamodb.region")).getOrElse(new DefaultAwsRegionProviderChain().getRegion)
+        builder.withCredentials(credentials)
+
+        Option(System.getProperty("aws.dynamodb.endpoint")).map(endpoint => {
+            builder
+                .withEndpointConfiguration(new EndpointConfiguration(endpoint, region))
+        }).getOrElse(
+            builder.withRegion(region)).build()
+
+    }
+
+
     val keySchema: KeySchema
 
     val readLimit: Double
diff --git a/src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoUpdatable.scala b/src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoUpdatable.scala
index 4846aed..d295f41 100644
--- a/src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoUpdatable.scala
+++ b/src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoUpdatable.scala
@@ -26,6 +26,6 @@ import org.apache.spark.sql.types.StructType
 
 trait DynamoUpdatable {
 
-    def updateItems(schema: StructType)(items: Iterator[Row]): Unit
+    def updateItems(schema: StructType, batchSize: Int)(items: Iterator[Row]): Unit
 
 }
diff --git a/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala b/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala
index aa07529..db8d7e9 100644
--- a/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala
+++ b/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala
@@ -37,10 +37,9 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
 
     private val consistentRead = parameters.getOrElse("stronglyConsistentReads", "false").toBoolean
     private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean
-    private val region = parameters.get("region")
 
     override val (keySchema, readLimit, writeLimit, itemLimit, totalSizeInBytes) = {
-        val table = getDynamoDB(region).getTable(tableName)
+        val table = getDynamoDB(parameters).getTable(tableName)
         val desc = table.describe()
 
         // Key schema.
@@ -83,10 +82,10 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
             scanSpec.withExpressionSpec(xspec.buildForScan())
         }
 
-        getDynamoDB(region).getTable(tableName).scan(scanSpec)
+        getDynamoDB(parameters).getTable(tableName).scan(scanSpec)
     }
 
-    override def updateItems(schema: StructType)(items: Iterator[Row]): Unit = {
+    override def updateItems(schema: StructType, batchSize: Int)(items: Iterator[Row]): Unit = {
         val columnNames = schema.map(_.name)
         val hashKeyIndex = columnNames.indexOf(keySchema.hashKeyName)
         val rangeKeyIndex = keySchema.rangeKeyName.map(columnNames.indexOf)
@@ -98,32 +97,38 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
         })
 
         val rateLimiter = RateLimiter.create(writeLimit max 1)
-        val client = getDynamoDBClient(region)
+        val client = getDynamoDBAsyncClient(parameters)
 
         // For each item.
-        items.foreach(row => {
-            val key: Map[String, AttributeValue] = keySchema match {
-                case KeySchema(hashKey, None) => Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType))
-                case KeySchema(hashKey, Some(rangeKey)) =>
-                    Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType),
-                        rangeKey -> mapValueToAttributeValue(row(rangeKeyIndex.get), schema(rangeKey).dataType))
-
-            }
-            val nonNullColumnIndices = columnIndices.filter(c => row(c._2) != null)
-            val updateExpression = s"SET ${nonNullColumnIndices.map(c => s"${c._1}=:${c._1}").mkString(", ")}"
-            val expressionAttributeValues = nonNullColumnIndices.map(c => s":${c._1}" -> mapValueToAttributeValue(row(c._2), schema(c._1).dataType)).toMap.asJava
-            val updateItemReq = new UpdateItemRequest()
-                .withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL)
-                .withTableName(tableName)
-                .withKey(key.asJava)
-                .withUpdateExpression(updateExpression)
-                .withExpressionAttributeValues(expressionAttributeValues)
-
-            val updateItemResult = client.updateItem(updateItemReq)
-
-            handleUpdateItemResult(rateLimiter)(updateItemResult)
+        items.grouped(batchSize).foreach(itemBatch => {
+            val results = itemBatch.map(row => {
+                val key: Map[String, AttributeValue] = keySchema match {
+                    case KeySchema(hashKey, None) => Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType))
+                    case KeySchema(hashKey, Some(rangeKey)) =>
+                        Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType),
+                            rangeKey -> mapValueToAttributeValue(row(rangeKeyIndex.get), schema(rangeKey).dataType))
+
+                }
+                val nonNullColumnIndices = columnIndices.filter(c => row(c._2) != null)
+                val updateExpression = s"SET ${nonNullColumnIndices.map(c => s"#${c._2}=:${c._2}").mkString(", ")}"
+                val expressionAttributeValues = nonNullColumnIndices.map(c => s":${c._2}" -> mapValueToAttributeValue(row(c._2), schema(c._1).dataType)).toMap.asJava
+                val updateItemReq = new UpdateItemRequest()
+                    .withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL)
+                    .withTableName(tableName)
+                    .withKey(key.asJava)
+                    .withUpdateExpression(updateExpression)
+                    .withExpressionAttributeNames(nonNullColumnIndices.map(c => s"#${c._2}" -> c._1).toMap.asJava)
+                    .withExpressionAttributeValues(expressionAttributeValues)
+
+                client.updateItemAsync(updateItemReq)
+            })
+            val unitsSpent = results.map(f => (try { Option(f.get()) } catch { case _: Exception => Option.empty })
+                .flatMap(c => Option(c.getConsumedCapacity))
+                .map(_.getCapacityUnits)
+                .getOrElse(Double.box(1.0))).reduce((a, b) => a + b)
+            rateLimiter.acquire(unitsSpent.toInt)
         })
     }
 
@@ -139,7 +144,7 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
         })
 
         val rateLimiter = RateLimiter.create(writeLimit max 1)
-        val client = getDynamoDB(region)
+        val client = getDynamoDB(parameters)
 
         // For each batch.
         items.grouped(batchSize).foreach(itemBatch => {
@@ -224,12 +229,5 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
             handleBatchWriteResponse(client, rateLimiter)(newResponse)
         }
     }
-    private def handleUpdateItemResult(rateLimiter: RateLimiter)
-                                      (result: UpdateItemResult): Unit = {
-        // Rate limit on write capacity.
-        if (result.getConsumedCapacity != null) {
-            rateLimiter.acquire(result.getConsumedCapacity.getCapacityUnits.toInt)
-        }
-    }
 
 }
diff --git a/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala b/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala
index 9563918..3253fcc 100644
--- a/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala
+++ b/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala
@@ -36,7 +36,7 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
     private val region = parameters.get("region")
 
     override val (keySchema, readLimit, itemLimit, totalSizeInBytes) = {
-        val table = getDynamoDB(region).getTable(tableName)
+        val table = getDynamoDB(parameters).getTable(tableName)
         val indexDesc = table.describe().getGlobalSecondaryIndexes.asScala.find(_.getIndexName == indexName).get
 
         // Key schema.
@@ -76,7 +76,7 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
             scanSpec.withExpressionSpec(xspec.buildForScan())
         }
 
-        getDynamoDB(region).getTable(tableName).getIndex(indexName).scan(scanSpec)
+        getDynamoDB(parameters).getTable(tableName).getIndex(indexName).scan(scanSpec)
     }
 
 }
diff --git a/src/main/scala/com/audienceproject/spark/dynamodb/rdd/DynamoWriteRelation.scala b/src/main/scala/com/audienceproject/spark/dynamodb/rdd/DynamoWriteRelation.scala
index 61e4b81..5cfb38e 100644
--- a/src/main/scala/com/audienceproject/spark/dynamodb/rdd/DynamoWriteRelation.scala
+++ b/src/main/scala/com/audienceproject/spark/dynamodb/rdd/DynamoWriteRelation.scala
@@ -45,7 +45,7 @@ private[dynamodb] class DynamoWriteRelation(data: DataFrame, parameters: Map[Str
     }
 
     def update(): Unit = {
-        data.foreachPartition(connector.updateItems(schema) _)
+        data.foreachPartition(connector.updateItems(schema, batchSize) _)
     }
 
diff --git a/src/test/scala/com/audienceproject/spark/dynamodb/WriteRelationTest.scala b/src/test/scala/com/audienceproject/spark/dynamodb/WriteRelationTest.scala
index 015664e..5c8ed2e 100644
--- a/src/test/scala/com/audienceproject/spark/dynamodb/WriteRelationTest.scala
+++ b/src/test/scala/com/audienceproject/spark/dynamodb/WriteRelationTest.scala
@@ -69,7 +69,7 @@ class WriteRelationTest extends AbstractInMemoryTest {
         newItemsDs.write.dynamodb(tablename)
 
         newItemsDs
-            .withColumn("size",length($"color"))
+            .withColumn("si:ze",length($"color"))
            .drop("color")
             .withColumn("weight",$"weight"*2)
             .write.option("update","true").dynamodb(tablename)
@@ -80,7 +80,7 @@ class WriteRelationTest extends AbstractInMemoryTest {
         assert(validationDs.select("name").as[String].collect().forall(Seq("lemon", "orange", "pomegranate") contains _))
         assert(validationDs.select("color").as[String].collect().forall(Seq("yellow", "orange", "red") contains _))
         assert(validationDs.select("weight").as[Double].collect().forall(Seq(0.2, 0.4, 0.4) contains _))
-        assert(validationDs.select("size").as[Long].collect().forall(Seq(6,3) contains _))
+        assert(validationDs.select("si:ze").as[Long].collect().forall(Seq(6,3) contains _))
     }