diff --git a/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala b/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala
index d2fbb41..7e7251e 100644
--- a/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala
+++ b/src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala
@@ -28,6 +28,7 @@ import com.audienceproject.shaded.google.common.util.concurrent.RateLimiter
 import com.audienceproject.spark.dynamodb.catalyst.JavaConverter
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.Filter
+import org.slf4j.LoggerFactory
 
 import scala.annotation.tailrec
 import scala.collection.JavaConverters._
@@ -39,9 +40,9 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
     private val filterPushdown = parameters.getOrElse("filterpushdown", "true").toBoolean
     private val region = parameters.get("region")
     private val roleArn = parameters.get("rolearn")
-
+    private val maxRetries = parameters.getOrElse("maxretries", "3").toInt
     override val filterPushdownEnabled: Boolean = filterPushdown
-
+    private val logger = LoggerFactory.getLogger(this.getClass)
     override val (keySchema, readLimit, writeLimit, itemLimit, totalSegments) = {
         val table = getDynamoDB(region, roleArn).getTable(tableName)
         val desc = table.describe()
@@ -54,7 +55,10 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
         val maxPartitionBytes = parameters.getOrElse("maxpartitionbytes", "128000000").toInt
         val targetCapacity = parameters.getOrElse("targetcapacity", "1").toDouble
         val readFactor = if (consistentRead) 1 else 2
-
+        // Write parallelisation parameter. It depends on the number of input partitions the DataFrame is distributed over.
+        // It can be passed via the numInputDFPartitions option.
+        // By default it falls back to Spark's default parallelism.
+        val numTasks = parameters.getOrElse("numinputdfpartitions", parallelism.toString).toInt
         // Table parameters.
         val tableSize = desc.getTableSizeBytes
         val itemCount = desc.getItemCount
@@ -66,25 +70,29 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
             if (remainder > 0) sizeBased + (parallelism - remainder)
             else sizeBased
         })
-
-        // Provisioned or on-demand throughput.
-        val readThroughput = parameters.getOrElse("throughput", Option(desc.getProvisionedThroughput.getReadCapacityUnits)
-            .filter(_ > 0).map(_.longValue().toString)
-            .getOrElse("100")).toLong
-        val writeThroughput = parameters.getOrElse("throughput", Option(desc.getProvisionedThroughput.getWriteCapacityUnits)
-            .filter(_ > 0).map(_.longValue().toString)
-            .getOrElse("100")).toLong
-
-        // Rate limit calculation.
-        val avgItemSize = tableSize.toDouble / itemCount
-        val readCapacity = readThroughput * targetCapacity
-        val writeCapacity = writeThroughput * targetCapacity
-
+        // If an absolute read/write throughput is provided, use it directly.
+        var readCapacity = parameters.getOrElse("absread", "-1").toDouble
+        var writeCapacity = parameters.getOrElse("abswrite", "-1").toDouble
+        // Otherwise fall back to the table's provisioned (or default) throughput.
+        if (readCapacity < 0) {
+            val readThroughput = parameters.getOrElse("throughput", Option(desc.getProvisionedThroughput.getReadCapacityUnits)
+                .filter(_ > 0).map(_.longValue().toString)
+                .getOrElse("100")).toLong
+            readCapacity = readThroughput * targetCapacity
+        }
+        if (writeCapacity < 0) {
+            val writeThroughput = parameters.getOrElse("throughput", Option(desc.getProvisionedThroughput.getWriteCapacityUnits)
+                .filter(_ > 0).map(_.longValue().toString)
+                .getOrElse("100")).toLong
+            // Rate limit calculation.
+            writeCapacity = writeThroughput * targetCapacity
+        }
+        // Calculate the write limit for each task from the number of parallel tasks, the target capacity and the WCU limit.
+        val writeLimit = writeCapacity / numTasks
         val readLimit = readCapacity / parallelism
+        val avgItemSize = tableSize.toDouble / itemCount
         val itemLimit = ((bytesPerRCU / avgItemSize * readLimit).toInt * readFactor) max 1
 
-        val writeLimit = writeCapacity / parallelism
-
         (keySchema, readLimit, writeLimit, itemLimit, numPartitions)
     }
 
@@ -140,7 +148,7 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
         ))
 
         val response = client.batchWriteItem(batchWriteItemSpec)
-        handleBatchWriteResponse(client, rateLimiter)(response)
+        handleBatchWriteResponse(client, rateLimiter)(response, 0)
     }
 
     override def updateItem(columnSchema: ColumnSchema, row: InternalRow)
@@ -196,12 +204,12 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
         batchWriteItemSpec.withTableWriteItems(tableWriteItemsWithItems)
 
         val response = client.batchWriteItem(batchWriteItemSpec)
-        handleBatchWriteResponse(client, rateLimiter)(response)
+        handleBatchWriteResponse(client, rateLimiter)(response, 0)
     }
 
     @tailrec
     private def handleBatchWriteResponse(client: DynamoDB, rateLimiter: RateLimiter)
-                                        (response: BatchWriteItemOutcome): Unit = {
+                                        (response: BatchWriteItemOutcome, retries: Int): Unit = {
         // Rate limit on write capacity.
         if (response.getBatchWriteItemResult.getConsumedCapacity != null) {
             response.getBatchWriteItemResult.getConsumedCapacity.asScala.map(cap => {
@@ -210,8 +218,21 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
         }
         // Retry unprocessed items.
         if (response.getUnprocessedItems != null && !response.getUnprocessedItems.isEmpty) {
-            val newResponse = client.batchWriteItemUnprocessed(response.getUnprocessedItems)
-            handleBatchWriteResponse(client, rateLimiter)(newResponse)
+            logger.info("Unprocessed items found")
+            if (retries < maxRetries) {
+                val newResponse = client.batchWriteItemUnprocessed(response.getUnprocessedItems)
+                handleBatchWriteResponse(client, rateLimiter)(newResponse, retries + 1)
+            }
+            else {
+                val unprocessed = response.getUnprocessedItems
+                // Log the number of unprocessed items per table.
+                unprocessed.asScala.foreach(keyValue =>
+                    logger.info("Maximum retries reached while writing items to DynamoDB. "
+ + "Number of unprocessed items of table \"" + keyValue._1 +"\" = " + + keyValue._2.asScala.length) + ) + + } } } diff --git a/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoBatchWriter.scala b/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoBatchWriter.scala index 305b47f..946eb5c 100644 --- a/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoBatchWriter.scala +++ b/src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoBatchWriter.scala @@ -36,7 +36,6 @@ class DynamoBatchWriter(batchSize: Int, protected val buffer: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow](batchSize) protected val rateLimiter: RateLimiter = RateLimiter.create(connector.writeLimit) - override def write(record: InternalRow): Unit = { buffer += record.copy() if (buffer.size == batchSize) {