@@ -24,7 +24,7 @@ import java.util
24
24
25
25
import com .amazonaws .services .dynamodbv2 .document ._
26
26
import com .amazonaws .services .dynamodbv2 .document .spec .{BatchWriteItemSpec , ScanSpec , UpdateItemSpec }
27
- import com .amazonaws .services .dynamodbv2 .model .ReturnConsumedCapacity
27
+ import com .amazonaws .services .dynamodbv2 .model .{ AttributeValue , ReturnConsumedCapacity , UpdateItemRequest }
28
28
import com .amazonaws .services .dynamodbv2 .xspec .ExpressionSpecBuilder
29
29
import com .amazonaws .services .dynamodbv2 .xspec .ExpressionSpecBuilder .{BOOL => newBOOL , L => newL , M => newM , N => newN , S => newS }
30
30
import com .google .common .util .concurrent .RateLimiter
@@ -143,7 +143,7 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
143
143
})
144
144
}
145
145
146
- override def updateItems (schema : StructType )(items : Iterator [Row ]): Unit = {
146
+ override def updateItems (schema : StructType , batchSize : Int )(items : Iterator [Row ]): Unit = {
147
147
val columnNames = schema.map(_.name)
148
148
val hashKeyIndex = columnNames.indexOf(keySchema.hashKeyName)
149
149
val rangeKeyIndex = keySchema.rangeKeyName.map(columnNames.indexOf)
@@ -155,48 +155,38 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
155
155
})
156
156
157
157
val rateLimiter = RateLimiter .create(writeLimit max 1 )
158
- val client = getDynamoDB(region, roleArn)
159
-
160
- // For each item.
161
- items.foreach(row => {
162
- // Build update expression.
163
- val xspec = new ExpressionSpecBuilder ()
164
- columnIndices.foreach({
165
- case (name, index) if ! row.isNullAt(index) =>
166
- val updateAction = schema(name).dataType match {
167
- case StringType => newS(name).set(row.getString(index))
168
- case BooleanType => newBOOL(name).set(row.getBoolean(index))
169
- case IntegerType => newN(name).set(row.getInt(index))
170
- case LongType => newN(name).set(row.getLong(index))
171
- case ShortType => newN(name).set(row.getShort(index))
172
- case FloatType => newN(name).set(row.getFloat(index))
173
- case DoubleType => newN(name).set(row.getDouble(index))
174
- case ArrayType (innerType, _) => newL(name).set(row.getSeq[Any ](index).map(e => mapValue(e, innerType)).asJava)
175
- case MapType (keyType, valueType, _) =>
176
- if (keyType != StringType ) throw new IllegalArgumentException (
177
- s " Invalid Map key type ' ${keyType.typeName}'. DynamoDB only supports String as Map key type. " )
178
- newM(name).set(row.getMap[String , Any ](index).mapValues(e => mapValue(e, valueType)).asJava)
179
- case StructType (fields) => newM(name).set(mapStruct(row.getStruct(index), fields))
180
- }
181
- xspec.addUpdate(updateAction)
182
- case _ =>
183
- })
158
+ val client = getDynamoDBAsyncClient(region,roleArn)
184
159
185
- val updateItemSpec = new UpdateItemSpec ()
186
- .withExpressionSpec(xspec.buildForUpdate())
187
- .withReturnConsumedCapacity(ReturnConsumedCapacity .TOTAL )
188
160
189
- // Map primary key.
190
- keySchema match {
191
- case KeySchema (hashKey, None ) => updateItemSpec.withPrimaryKey(hashKey, row(hashKeyIndex))
192
- case KeySchema (hashKey, Some (rangeKey)) =>
193
- updateItemSpec.withPrimaryKey(hashKey, row(hashKeyIndex), rangeKey, row(rangeKeyIndex.get))
194
- }
195
161
196
- if (updateItemSpec.getUpdateExpression.nonEmpty) {
197
- val response = client.getTable(tableName).updateItem(updateItemSpec)
198
- handleUpdateResponse(rateLimiter)(response)
199
- }
162
+ // For each item.
163
+ items.grouped(batchSize).foreach(itemBatch => {
164
+ val results = itemBatch.map(row => {
165
+ val key : Map [String ,AttributeValue ] = keySchema match {
166
+ case KeySchema (hashKey, None ) => Map (hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType))
167
+ case KeySchema (hashKey, Some (rangeKey)) =>
168
+ Map (hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType),
169
+ rangeKey-> mapValueToAttributeValue(row(rangeKeyIndex.get), schema(rangeKey).dataType))
170
+
171
+ }
172
+ val nonNullColumnIndices = columnIndices.filter(c => row(c._2)!= null )
173
+ val updateExpression = s " SET ${nonNullColumnIndices.map(c => s " # ${c._2}=: ${c._2}" ).mkString(" , " )}"
174
+ val expressionAttributeValues = nonNullColumnIndices.map(c => s " : ${c._2}" -> mapValueToAttributeValue(row(c._2), schema(c._1).dataType)).toMap.asJava
175
+ val updateItemReq = new UpdateItemRequest ()
176
+ .withReturnConsumedCapacity(ReturnConsumedCapacity .TOTAL )
177
+ .withTableName(tableName)
178
+ .withKey(key.asJava)
179
+ .withUpdateExpression(updateExpression)
180
+ .withExpressionAttributeNames(nonNullColumnIndices.map(c=> s " # ${c._2}" -> c._1).toMap.asJava)
181
+ .withExpressionAttributeValues(expressionAttributeValues)
182
+
183
+ client.updateItemAsync(updateItemReq)
184
+ })
185
+ val unitsSpent = results.map(f => (try { Option (f.get()) } catch { case _:Exception => Option .empty })
186
+ .flatMap(c => Option (c.getConsumedCapacity))
187
+ .map(_.getCapacityUnits)
188
+ .getOrElse(Double .box(1.0 ))).reduce((a,b)=> a+ b)
189
+ rateLimiter.acquire(unitsSpent.toInt)
200
190
})
201
191
}
202
192
@@ -214,6 +204,26 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
214
204
}
215
205
}
216
206
207
+ private def mapValueToAttributeValue (element : Any , elementType : DataType ): AttributeValue = {
208
+ elementType match {
209
+ case ArrayType (innerType, _) => new AttributeValue ().withL(element.asInstanceOf [Seq [_]].map(e => mapValueToAttributeValue(e, innerType)):_* )
210
+ case MapType (keyType, valueType, _) =>
211
+ if (keyType != StringType ) throw new IllegalArgumentException (
212
+ s " Invalid Map key type ' ${keyType.typeName}'. DynamoDB only supports String as Map key type. " )
213
+
214
+ new AttributeValue ().withM(element.asInstanceOf [Map [String , _]].mapValues(e => mapValueToAttributeValue(e, valueType)).asJava)
215
+
216
+ case StructType (fields) =>
217
+ val row = element.asInstanceOf [Row ]
218
+ new AttributeValue ().withM( (fields.indices map { i =>
219
+ fields(i).name -> mapValueToAttributeValue(row(i), fields(i).dataType)
220
+ }).toMap.asJava)
221
+ case StringType => new AttributeValue ().withS(element.asInstanceOf [String ])
222
+ case LongType | IntegerType | DoubleType | FloatType => new AttributeValue ().withN(element.toString)
223
+ case BooleanType => new AttributeValue ().withBOOL(element.asInstanceOf [Boolean ])
224
+ }
225
+ }
226
+
217
227
private def mapStruct (row : Row , fields : Seq [StructField ]): util.Map [String , Any ] =
218
228
(fields.indices map { i =>
219
229
fields(i).name -> mapValue(row(i), fields(i).dataType)
0 commit comments