diff --git a/README.md b/README.md
index a43ad7e..1f6a9df 100644
--- a/README.md
+++ b/README.md
@@ -70,6 +70,7 @@ dynamoDf.show(100)
 
 # write to some other table overwriting existing item with same keys
 dynamoDf.write.option("tableName", "SomeOtherTable") \
+    .mode("append") \
     .format("dynamodb") \
     .save()
 ```
@@ -83,25 +84,35 @@ pyspark --packages com.audienceproject:spark-dynamodb_
diff --git a/src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala b/src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala
--- a/src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala
+++ b/src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala
-            AmazonDynamoDBClientBuilder.standard()
-                .withCredentials(credentials)
-                .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
-                .build()
-        }).getOrElse(
-            AmazonDynamoDBClientBuilder.standard()
+        if (omitDax || daxEndpoint.isEmpty) {
+            logger.info("NOT using DAX")
+            properties.get("aws.dynamodb.endpoint").map(endpoint => {
+                logger.debug(s"Using DynamoDB endpoint ${endpoint}")
+                AmazonDynamoDBClientBuilder.standard()
+                    .withCredentials(credentials)
+                    .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
+                    .build()
+            }).getOrElse(
+                AmazonDynamoDBClientBuilder.standard()
+                    .withCredentials(credentials)
+                    .withRegion(chosenRegion)
+                    .build()
+            )
+        } else {
+            logger.debug(s"Using DAX endpoint ${daxEndpoint}")
+            AmazonDaxClientBuilder.standard()
+                .withEndpointConfiguration(daxEndpoint)
                 .withCredentials(credentials)
                 .withRegion(chosenRegion)
                 .build()
-        )
+        }
     }
 
     def getDynamoDBAsyncClient(region: Option[String] = None,
                                roleArn: Option[String] = None,
-                               providerClassName: Option[String] = None): AmazonDynamoDBAsync = {
+                               providerClassName: Option[String] = None,
+                               omitDax: Boolean = false): AmazonDynamoDBAsync = {
         val chosenRegion = region.getOrElse(properties.getOrElse("aws.dynamodb.region", "us-east-1"))
         val credentials = getCredentials(chosenRegion, roleArn, providerClassName)
-        properties.get("aws.dynamodb.endpoint").map(endpoint => {
-            AmazonDynamoDBAsyncClientBuilder.standard()
-                .withCredentials(credentials)
-                .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
-                .build()
-        }).getOrElse(
-            AmazonDynamoDBAsyncClientBuilder.standard()
+        if (omitDax || daxEndpoint.isEmpty) {
+            properties.get("aws.dynamodb.endpoint").map(endpoint => {
+                logger.debug(s"Using DynamoDB endpoint ${endpoint}")
+                AmazonDynamoDBAsyncClientBuilder.standard()
+                    .withCredentials(credentials)
+                    .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
+                    .build()
+            }).getOrElse(
+                AmazonDynamoDBAsyncClientBuilder.standard()
+                    .withCredentials(credentials)
+                    .withRegion(chosenRegion)
+                    .build()
+            )
+        } else {
+            logger.debug(s"Using DAX endpoint ${daxEndpoint}")
+            AmazonDaxAsyncClientBuilder.standard()
+                .withEndpointConfiguration(daxEndpoint)
                 .withCredentials(credentials)
                 .withRegion(chosenRegion)
                 .build()
-        )
+        }
     }
 
     /**
@@ -126,6 +156,8 @@ private[dynamodb] trait DynamoConnector {
 
     val filterPushdownEnabled: Boolean
 
+    val daxEndpoint: String
+
     def scan(segmentNum: Int, columns: Seq[String], filters: Seq[Filter]): ItemCollection[ScanOutcome]
 
     def isEmpty: Boolean = itemLimit == 0
diff --git a/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala b/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala
index 80e2bc4..058e5f1 100644
--- a/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala
+++ b/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala
@@ -32,6 +32,12 @@ import org.apache.spark.sql.sources.Filter
 import scala.annotation.tailrec
 import scala.collection.JavaConverters._
 
+/**
+ * Connector for reading from and writing to a single DynamoDB table.
+ * @param tableName   name of the DynamoDB table
+ * @param parallelism degree of parallelism, used to size scan segments and throughput limits
+ * @param parameters  data source options as a case-sensitive Map whose keys have all been lowercased
+ */
 private[dynamodb] class TableConnector(tableName: String, parallelism: Int, parameters: Map[String, String])
     extends DynamoConnector with DynamoWritable with Serializable {
 
@@ -43,8 +49,10 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
 
     override val filterPushdownEnabled: Boolean = filterPushdown
 
+    override val daxEndpoint: String = parameters.getOrElse("daxendpoint", "").trim
+
     override val (keySchema, readLimit, writeLimit, itemLimit, totalSegments) = {
-        val table = getDynamoDB(region, roleArn, providerClassName).getTable(tableName)
+        val table = getDynamoDB(region, roleArn, providerClassName, omitDax = true).getTable(tableName)
         val desc = table.describe()
 
         // Key schema.
diff --git a/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala b/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala
index 0d1c213..882de0c 100644
--- a/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala
+++ b/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala
@@ -39,6 +39,8 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
 
     override val filterPushdownEnabled: Boolean = filterPushdown
 
+    override val daxEndpoint: String = parameters.getOrElse("daxendpoint", "").trim
+
     override val (keySchema, readLimit, itemLimit, totalSegments) = {
-        val table = getDynamoDB(region, roleArn, providerClassName).getTable(tableName)
+        val table = getDynamoDB(region, roleArn, providerClassName, omitDax = true).getTable(tableName)
         val indexDesc = table.describe().getGlobalSecondaryIndexes.asScala.find(_.getIndexName == indexName).get
diff --git a/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoDataWriter.scala b/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoDataWriter.scala
index 14c4610..d9c44ed 100644
--- a/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoDataWriter.scala
+++ b/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoDataWriter.scala
@@ -51,7 +51,10 @@ class DynamoDataWriter(batchSize: Int,
 
     override def abort(): Unit = {}
 
-    override def close(): Unit = client.shutdown()
+    override def close(): Unit = {
+        buffer.clear()
+        client.shutdown()
+    }
 
     protected def flush(): Unit = {
         if (buffer.nonEmpty) {
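
For reference, a minimal usage sketch of the DAX support introduced by this patch. The option key `daxEndpoint` is an assumption based on the lowercased `daxendpoint` lookup in `TableConnector`/`TableIndexConnector` (the README hunk documenting the option is truncated above), and the cluster URL is made up; this is not the library's documented API.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("spark-dynamodb DAX sketch")
  .getOrCreate()

// Hypothetical DAX cluster endpoint; replace with your own cluster's URL.
val daxEndpoint = "dax://my-cluster.abc123.dax-clusters.us-east-1.amazonaws.com"

// Read through DAX: scans go through the AmazonDaxClient, while the initial
// describe() used for schema and throughput discovery bypasses DAX (omitDax = true).
val dynamoDf = spark.read
  .option("tableName", "SomeTableName")
  .option("daxEndpoint", daxEndpoint)
  .format("dynamodb")
  .load()

// Write through the same endpoint; leaving daxEndpoint unset (or empty) falls back
// to the plain DynamoDB client via the `omitDax || daxEndpoint.isEmpty` branch.
dynamoDf.write
  .option("tableName", "SomeOtherTable")
  .option("daxEndpoint", daxEndpoint)
  .mode("append")
  .format("dynamodb")
  .save()
```

Keeping the `describe()` call on the plain DynamoDB client matters because DAX only serves item-level data-plane operations (GetItem, PutItem, Query, Scan), not control-plane calls such as DescribeTable, which is presumably why the connectors pass `omitDax = true` when fetching table metadata.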