20
20
*/
21
21
package com .audienceproject .spark .dynamodb .connector
22
22
23
+ import java .util
24
+
23
25
import com .amazonaws .services .dynamodbv2 .document ._
24
- import com .amazonaws .services .dynamodbv2 .document .spec .{BatchWriteItemSpec , ScanSpec }
25
- import com .amazonaws .services .dynamodbv2 .model .{ AttributeValue , ReturnConsumedCapacity , UpdateItemRequest , UpdateItemResult }
26
+ import com .amazonaws .services .dynamodbv2 .document .spec .{BatchWriteItemSpec , ScanSpec , UpdateItemSpec }
27
+ import com .amazonaws .services .dynamodbv2 .model .ReturnConsumedCapacity
26
28
import com .amazonaws .services .dynamodbv2 .xspec .ExpressionSpecBuilder
29
+ import com .amazonaws .services .dynamodbv2 .xspec .ExpressionSpecBuilder .{BOOL => newBOOL , L => newL , M => newM , N => newN , S => newS }
27
30
import com .google .common .util .concurrent .RateLimiter
28
31
import org .apache .spark .sql .Row
29
32
import org .apache .spark .sql .sources .Filter
@@ -94,45 +97,6 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
94
97
getDynamoDB(region).getTable(tableName).scan(scanSpec)
95
98
}
96
99
97
/**
  * Issues one UpdateItem request per row, setting only the non-null, non-key
  * columns. Key attributes are taken from the row positions matching the
  * table's hash (and optional range) key.
  *
  * Fixes over the previous revision:
  *  - rows with no non-null non-key columns are skipped (previously they
  *    produced the invalid expression "SET " which DynamoDB rejects);
  *  - attribute names go through "#"-prefixed placeholders so columns named
  *    after DynamoDB reserved words (e.g. "name", "status") can be updated;
  *  - null detection uses Row.isNullAt instead of a reference comparison.
  *
  * @param schema Spark schema of the incoming rows; column order must match.
  * @param items  rows to write; consumed exactly once.
  */
override def updateItems(schema: StructType)(items: Iterator[Row]): Unit = {
    val columnNames = schema.map(_.name)
    val hashKeyIndex = columnNames.indexOf(keySchema.hashKeyName)
    val rangeKeyIndex = keySchema.rangeKeyName.map(columnNames.indexOf)
    // Non-key columns are the only candidates for the SET expression.
    val columnIndices = columnNames.zipWithIndex.filterNot({
        case (name, _) => keySchema match {
            case KeySchema(hashKey, None) => name == hashKey
            case KeySchema(hashKey, Some(rangeKey)) => name == hashKey || name == rangeKey
        }
    })

    val rateLimiter = RateLimiter.create(writeLimit max 1)
    val client = getDynamoDBClient(region)

    // For each item.
    items.foreach(row => {
        val key: Map[String, AttributeValue] = keySchema match {
            case KeySchema(hashKey, None) =>
                Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType))
            case KeySchema(hashKey, Some(rangeKey)) =>
                Map(hashKey -> mapValueToAttributeValue(row(hashKeyIndex), schema(hashKey).dataType),
                    rangeKey -> mapValueToAttributeValue(row(rangeKeyIndex.get), schema(rangeKey).dataType))
        }
        // isNullAt handles primitive-backed columns correctly, unlike != null.
        val nonNullColumnIndices = columnIndices.filter(c => !row.isNullAt(c._2))
        // Nothing to update for this row -> skip; an empty SET clause is invalid.
        if (nonNullColumnIndices.nonEmpty) {
            // "#col" name placeholders sidestep DynamoDB's reserved-word list.
            // NOTE(review): placeholder validity still assumes column names are
            // alphanumeric — confirm against the table schemas in use.
            val updateExpression = s"SET ${nonNullColumnIndices.map(c => s"#${c._1}=:${c._1}").mkString(", ")}"
            val expressionAttributeNames = nonNullColumnIndices
                .map(c => s"#${c._1}" -> c._1).toMap.asJava
            val expressionAttributeValues = nonNullColumnIndices
                .map(c => s":${c._1}" -> mapValueToAttributeValue(row(c._2), schema(c._1).dataType))
                .toMap.asJava
            val updateItemReq = new UpdateItemRequest()
                .withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL)
                .withTableName(tableName)
                .withKey(key.asJava)
                .withUpdateExpression(updateExpression)
                .withExpressionAttributeNames(expressionAttributeNames)
                .withExpressionAttributeValues(expressionAttributeValues)

            val updateItemResult = client.updateItem(updateItemReq)

            // Throttle on the consumed write capacity reported by the service.
            handleUpdateItemResult(rateLimiter)(updateItemResult)
        }
    })
}
136
100
override def putItems (schema : StructType , batchSize : Int )(items : Iterator [Row ]): Unit = {
137
101
val columnNames = schema.map(_.name)
138
102
val hashKeyIndex = columnNames.indexOf(keySchema.hashKeyName)
@@ -174,46 +138,85 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
174
138
))
175
139
176
140
val response = client.batchWriteItem(batchWriteItemSpec)
177
-
178
141
handleBatchWriteResponse(client, rateLimiter)(response)
179
142
})
180
143
}
181
144
145
/**
  * Issues one UpdateItem request per row via the document API, setting only
  * the non-null, non-key columns. Rows whose non-key columns are all null are
  * skipped entirely.
  *
  * Fixes: `getUpdateExpression.nonEmpty` could NPE when the expression builder
  * produced no update actions (the getter may return null) — the emptiness
  * check is now null-safe. The Table handle is also hoisted out of the loop,
  * since it only depends on the table name.
  *
  * @param schema Spark schema of the incoming rows; column order must match.
  * @param items  rows to write; consumed exactly once.
  */
override def updateItems(schema: StructType)(items: Iterator[Row]): Unit = {
    val columnNames = schema.map(_.name)
    val hashKeyIndex = columnNames.indexOf(keySchema.hashKeyName)
    val rangeKeyIndex = keySchema.rangeKeyName.map(columnNames.indexOf)
    // Non-key columns are the only candidates for update actions.
    val columnIndices = columnNames.zipWithIndex.filterNot({
        case (name, _) => keySchema match {
            case KeySchema(hashKey, None) => name == hashKey
            case KeySchema(hashKey, Some(rangeKey)) => name == hashKey || name == rangeKey
        }
    })

    val rateLimiter = RateLimiter.create(writeLimit max 1)
    val client = getDynamoDB(region)
    // Loop-invariant: the Table handle depends only on the table name.
    val table = client.getTable(tableName)

    // For each item.
    items.foreach(row => {
        // Build update expression: one SET action per non-null non-key column.
        val xspec = new ExpressionSpecBuilder()
        columnIndices.foreach({
            case (name, index) if !row.isNullAt(index) =>
                val updateAction = schema(name).dataType match {
                    case StringType => newS(name).set(row.getString(index))
                    case BooleanType => newBOOL(name).set(row.getBoolean(index))
                    case IntegerType => newN(name).set(row.getInt(index))
                    case LongType => newN(name).set(row.getLong(index))
                    case ShortType => newN(name).set(row.getShort(index))
                    case FloatType => newN(name).set(row.getFloat(index))
                    case DoubleType => newN(name).set(row.getDouble(index))
                    case ArrayType(innerType, _) =>
                        newL(name).set(row.getSeq[Any](index).map(e => mapValue(e, innerType)).asJava)
                    case MapType(keyType, valueType, _) =>
                        if (keyType != StringType) throw new IllegalArgumentException(
                            s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
                        newM(name).set(row.getMap[String, Any](index).mapValues(e => mapValue(e, valueType)).asJava)
                    case StructType(fields) => newM(name).set(mapStruct(row.getStruct(index), fields))
                }
                xspec.addUpdate(updateAction)
            case _ => // null column: no update action.
        })

        val updateItemSpec = new UpdateItemSpec()
            .withExpressionSpec(xspec.buildForUpdate())
            .withReturnConsumedCapacity(ReturnConsumedCapacity.TOTAL)

        // Map primary key.
        keySchema match {
            case KeySchema(hashKey, None) => updateItemSpec.withPrimaryKey(hashKey, row(hashKeyIndex))
            case KeySchema(hashKey, Some(rangeKey)) =>
                updateItemSpec.withPrimaryKey(hashKey, row(hashKeyIndex), rangeKey, row(rangeKeyIndex.get))
        }

        // Null-safe emptiness check: the expression may be null when no update
        // actions were added (all non-key columns null).
        if (Option(updateItemSpec.getUpdateExpression).exists(_.nonEmpty)) {
            val response = table.updateItem(updateItemSpec)
            handleUpdateResponse(rateLimiter)(response)
        }
    })
}
201
+
182
202
/**
  * Recursively converts a Spark value into the Java-native representation the
  * DynamoDB document API expects: Seq -> java List, Map -> java Map,
  * Row -> java Map of field name to converted value; scalars pass through.
  *
  * @throws IllegalArgumentException for maps whose key type is not String,
  *                                  since DynamoDB only supports String keys.
  */
private def mapValue(element: Any, elementType: DataType): Any = elementType match {
    case ArrayType(innerType, _) =>
        element.asInstanceOf[Seq[_]].map(mapValue(_, innerType)).asJava
    case MapType(keyType, valueType, _) =>
        if (keyType != StringType) throw new IllegalArgumentException(
            s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
        element.asInstanceOf[Map[String, _]].mapValues(mapValue(_, valueType)).asJava
    case StructType(fields) =>
        // Nested structs share the same field-by-field conversion as top level.
        mapStruct(element.asInstanceOf[Row], fields)
    case _ => element
}
197
215
198
/**
  * Recursively converts a Spark value into a low-level DynamoDB
  * AttributeValue: strings -> S, numbers -> N, booleans -> BOOL,
  * sequences -> L, maps and structs -> M.
  *
  * @throws IllegalArgumentException for maps whose key type is not String.
  *                                  (Unlisted data types fall through to a
  *                                  MatchError, as before.)
  */
private def mapValueToAttributeValue(element: Any, elementType: DataType): AttributeValue = elementType match {
    case StringType => new AttributeValue().withS(element.asInstanceOf[String])
    case LongType | IntegerType | DoubleType | FloatType => new AttributeValue().withN(element.toString)
    case BooleanType => new AttributeValue().withBOOL(element.asInstanceOf[Boolean])
    case ArrayType(innerType, _) =>
        val converted = element.asInstanceOf[Seq[_]].map(mapValueToAttributeValue(_, innerType))
        new AttributeValue().withL(converted: _*)
    case MapType(keyType, valueType, _) =>
        if (keyType != StringType) throw new IllegalArgumentException(
            s"Invalid Map key type '${keyType.typeName}'. DynamoDB only supports String as Map key type.")
        val converted = element.asInstanceOf[Map[String, _]].mapValues(mapValueToAttributeValue(_, valueType))
        new AttributeValue().withM(converted.asJava)
    case StructType(fields) =>
        val row = element.asInstanceOf[Row]
        val entries = fields.indices
            .map(i => fields(i).name -> mapValueToAttributeValue(row(i), fields(i).dataType))
            .toMap
        new AttributeValue().withM(entries.asJava)
}
216
/**
  * Converts a Row into a Java map keyed by field name, converting each field
  * value with mapValue according to its declared data type.
  */
private def mapStruct(row: Row, fields: Seq[StructField]): util.Map[String, Any] =
    fields.zipWithIndex.map({
        case (field, i) => field.name -> mapValue(row(i), field.dataType)
    }).toMap.asJava
217
220
218
221
@ tailrec
219
222
private def handleBatchWriteResponse (client : DynamoDB , rateLimiter : RateLimiter )
@@ -231,12 +234,12 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
231
234
}
232
235
}
233
236
234
/**
  * Rate-limits on the write capacity consumed by an UpdateItem call.
  * Consumed capacity may be absent from the outcome, in which case no
  * permits are acquired.
  */
private def handleUpdateResponse(rateLimiter: RateLimiter)
                                (response: UpdateItemOutcome): Unit = {
    // Rate limit on write capacity.
    Option(response.getUpdateItemResult.getConsumedCapacity)
        .foreach(capacity => rateLimiter.acquire(capacity.getCapacityUnits.toInt))
}
241
244
242
245
}
0 commit comments