This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Added an option for choosing absolute (cumulative) read and write throughput; handled an issue with parallelism while writing. #66

Open · wants to merge 7 commits into base: master
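A minimal usage sketch of the options this PR reads in the diff below (absread, abswrite, numinputdfpartitions, maxretries), assuming the implicits-based reader/writer API shown in the project README (spark.read.dynamodb / df.write.dynamodb); table names and numbers are placeholders, not part of the PR.

import com.audienceproject.spark.dynamodb.implicits._
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("AbsoluteThroughputExample").getOrCreate()

// Read with an absolute, cumulative budget of 400 RCU shared by all read tasks.
val df = spark.read
    .option("absread", "400")
    .dynamodb("SourceTable")

// Write with an absolute, cumulative budget of 1000 WCU split across 40 writer tasks,
// retrying unprocessed batch items up to 5 times.
df.repartition(40)
    .write
    .option("abswrite", "1000")
    .option("numinputdfpartitions", "40")
    .option("maxretries", "5")
    .dynamodb("TargetTable")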
@@ -28,6 +28,7 @@ import com.audienceproject.shaded.google.common.util.concurrent.RateLimiter
import com.audienceproject.spark.dynamodb.catalyst.JavaConverter
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.Filter
import org.slf4j.LoggerFactory

import scala.annotation.tailrec
import scala.collection.JavaConverters._
@@ -39,9 +40,9 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
private val filterPushdown = parameters.getOrElse("filterpushdown", "true").toBoolean
private val region = parameters.get("region")
private val roleArn = parameters.get("rolearn")

private val maxRetries = parameters.getOrElse("maxretries", "3").toInt
override val filterPushdownEnabled: Boolean = filterPushdown

private val logger = LoggerFactory.getLogger(this.getClass)
override val (keySchema, readLimit, writeLimit, itemLimit, totalSegments) = {
val table = getDynamoDB(region, roleArn).getTable(tableName)
val desc = table.describe()
@@ -54,7 +55,10 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
val maxPartitionBytes = parameters.getOrElse("maxpartitionbytes", "128000000").toInt
val targetCapacity = parameters.getOrElse("targetcapacity", "1").toDouble
val readFactor = if (consistentRead) 1 else 2

// Write parallelisation parameter; depends on the number of input partitions the DataFrame is distributed in.
// It can be passed using the numInputDFPartitions option.
// By default it is Spark's default parallelism.
val numTasks = parameters.getOrElse("numinputdfpartitions", parallelism.toString).toInt
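// Illustrative example (placeholder numbers): with abswrite = 1000 and the DataFrame
// repartitioned into 40 partitions, passing numinputdfpartitions = 40 gives each writer
// task a rate limit of 1000 / 40 = 25 write capacity units per second.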
// Table parameters.
val tableSize = desc.getTableSizeBytes
val itemCount = desc.getItemCount
@@ -66,25 +70,29 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
if (remainder > 0) sizeBased + (parallelism - remainder)
else sizeBased
})

// Provisioned or on-demand throughput.
val readThroughput = parameters.getOrElse("throughput", Option(desc.getProvisionedThroughput.getReadCapacityUnits)
.filter(_ > 0).map(_.longValue().toString)
.getOrElse("100")).toLong
val writeThroughput = parameters.getOrElse("throughput", Option(desc.getProvisionedThroughput.getWriteCapacityUnits)
.filter(_ > 0).map(_.longValue().toString)
.getOrElse("100")).toLong

// Rate limit calculation.
val avgItemSize = tableSize.toDouble / itemCount
val readCapacity = readThroughput * targetCapacity
val writeCapacity = writeThroughput * targetCapacity

// If an absolute (cumulative) throughput is provided via the absread/abswrite options, use it directly.
var readCapacity = parameters.getOrElse("absread", "-1").toDouble
var writeCapacity = parameters.getOrElse("abswrite", "-1").toDouble
// Otherwise fall back to the "throughput" option or the table's provisioned throughput (default 100).
if (readCapacity < 0) {
val readThroughput = parameters.getOrElse("throughput", Option(desc.getProvisionedThroughput.getReadCapacityUnits)
.filter(_ > 0).map(_.longValue().toString)
.getOrElse("100")).toLong
readCapacity = readThroughput * targetCapacity
}
if (writeCapacity < 0) {
val writeThroughput = parameters.getOrElse("throughput", Option(desc.getProvisionedThroughput.getWriteCapacityUnits)
.filter(_ > 0).map(_.longValue().toString)
.getOrElse("100")).toLong
writeCapacity = writeThroughput * targetCapacity
}
// Calculate the write limit for each task, based on the number of parallel tasks, the target capacity, and the WCU limit.
val writeLimit = writeCapacity / numTasks
val readLimit = readCapacity / parallelism
val avgItemSize = tableSize.toDouble / itemCount
val itemLimit = ((bytesPerRCU / avgItemSize * readLimit).toInt * readFactor) max 1
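// Worked example (placeholder numbers, assuming the class's bytesPerRCU constant is 4000 bytes):
// absread = 400 with parallelism = 8 gives readLimit = 50 RCU per second per read task;
// abswrite = 1000 with numTasks = 40 gives writeLimit = 25 WCU per second per write task;
// with an average item size of 1000 bytes and eventually consistent reads (readFactor = 2),
// itemLimit = (4000 / 1000 * 50).toInt * 2 = 400 items per scan page.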

val writeLimit = writeCapacity / parallelism

(keySchema, readLimit, writeLimit, itemLimit, numPartitions)
}

@@ -140,7 +148,7 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
))

val response = client.batchWriteItem(batchWriteItemSpec)
handleBatchWriteResponse(client, rateLimiter)(response)
handleBatchWriteResponse(client, rateLimiter)(response, 0)
}

override def updateItem(columnSchema: ColumnSchema, row: InternalRow)
@@ -196,12 +204,12 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
batchWriteItemSpec.withTableWriteItems(tableWriteItemsWithItems)

val response = client.batchWriteItem(batchWriteItemSpec)
handleBatchWriteResponse(client, rateLimiter)(response)
handleBatchWriteResponse(client, rateLimiter)(response, 0)
}

@tailrec
private def handleBatchWriteResponse(client: DynamoDB, rateLimiter: RateLimiter)
(response: BatchWriteItemOutcome): Unit = {
(response: BatchWriteItemOutcome, retries: Int): Unit = {
// Rate limit on write capacity.
if (response.getBatchWriteItemResult.getConsumedCapacity != null) {
response.getBatchWriteItemResult.getConsumedCapacity.asScala.map(cap => {
@@ -210,8 +218,21 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
}
// Retry unprocessed items.
if (response.getUnprocessedItems != null && !response.getUnprocessedItems.isEmpty) {
val newResponse = client.batchWriteItemUnprocessed(response.getUnprocessedItems)
handleBatchWriteResponse(client, rateLimiter)(newResponse)
logger.info("Unprocessed items found in batch write response.")
if (retries < maxRetries) {
val newResponse = client.batchWriteItemUnprocessed(response.getUnprocessedItems)
handleBatchWriteResponse(client, rateLimiter)(newResponse, retries + 1)
}
else {
val unprocessed = response.getUnprocessedItems
// Log the items that could not be written after the maximum number of retries.
unprocessed.asScala.foreach(keyValue =>
logger.info("Maximum retries reached while writing items to DynamoDB. " +
"Number of unprocessed items for table \"" + keyValue._1 + "\" = " +
keyValue._2.asScala.length)
)
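// Note: items that remain unprocessed once maxRetries is reached are not written; only their per-table counts are logged.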

}
}
}

@@ -36,7 +36,6 @@ class DynamoBatchWriter(batchSize: Int,

protected val buffer: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow](batchSize)
protected val rateLimiter: RateLimiter = RateLimiter.create(connector.writeLimit)
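// Each writer task gets its own RateLimiter; its rate (permits per second) equals the
// connector's writeLimit, i.e. the cumulative write capacity divided across the write tasks.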

override def write(record: InternalRow): Unit = {
buffer += record.copy()
if (buffer.size == batchSize) {