Skip to content
This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit 3334458

Browse files
committed
assume role option
1 parent ee645ec commit 3334458

File tree

4 files changed

+37
-12
lines changed

4 files changed

+37
-12
lines changed

build.sbt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ organization := "com.audienceproject"
22

33
name := "spark-dynamodb"
44

5-
version := "0.4.2"
5+
version := "0.5.0"
66

77
description := "Plug-and-play implementation of an Apache Spark custom data source for AWS DynamoDB."
88

@@ -14,6 +14,7 @@ resolvers += "DynamoDBLocal" at "https://s3-us-west-2.amazonaws.com/dynamodb-loc
1414

1515
libraryDependencies += "com.amazonaws" % "aws-java-sdk-dynamodb" % "1.11.466"
1616
libraryDependencies += "com.amazonaws" % "DynamoDBLocal" % "[1.11,2.0)" % "test" exclude("com.google.guava", "guava")
17+
libraryDependencies += "com.amazonaws" % "aws-java-sdk" % "1.11.49"
1718

1819
libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.0" % "provided"
1920
libraryDependencies += "com.google.guava" % "guava" % "14.0.1" % "provided"

src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,23 @@
2020
*/
2121
package com.audienceproject.spark.dynamodb.connector
2222

23-
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
23+
import com.amazonaws.auth.{AWSStaticCredentialsProvider, BasicSessionCredentials, DefaultAWSCredentialsProviderChain}
2424
import com.amazonaws.auth.profile.ProfileCredentialsProvider
2525
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
2626
import com.amazonaws.services.dynamodbv2.document.{DynamoDB, ItemCollection, ScanOutcome}
2727
import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBClientBuilder}
28+
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder
29+
import com.amazonaws.services.securitytoken.model.AssumeRoleRequest
2830
import org.apache.spark.sql.sources.Filter
2931

3032
private[dynamodb] trait DynamoConnector {
3133

32-
def getDynamoDB(region: Option[String] = None): DynamoDB = {
33-
val client: AmazonDynamoDB = getDynamoDBClient(region)
34+
def getDynamoDB(region: Option[String] = None, assumedArn: Option[String] = None): DynamoDB = {
35+
val client: AmazonDynamoDB = getDynamoDBClient(region, assumedArn)
3436
new DynamoDB(client)
3537
}
3638

37-
def getDynamoDBClient(region: Option[String] = None): AmazonDynamoDB = {
39+
def getDynamoDBClient(region: Option[String] = None, assumedArn: Option[String] = None): AmazonDynamoDB = {
3840
val chosenRegion = region.getOrElse(sys.env.getOrElse("aws.dynamodb.region", "us-east-1"))
3941
Option(System.getProperty("aws.dynamodb.endpoint")).map(endpoint => {
4042
val credentials = Option(System.getProperty("aws.profile"))
@@ -44,7 +46,27 @@ private[dynamodb] trait DynamoConnector {
4446
.withCredentials(credentials)
4547
.withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
4648
.build()
47-
}).getOrElse(AmazonDynamoDBClientBuilder.standard().withRegion(chosenRegion).build())
49+
}).getOrElse(
50+
assumedArn.map(arn => {
51+
val stsClient = AWSSecurityTokenServiceClientBuilder
52+
.standard()
53+
.withCredentials(new DefaultAWSCredentialsProviderChain)
54+
.withRegion(chosenRegion)
55+
.build()
56+
val assumeRoleResult = stsClient.assumeRole(
57+
new AssumeRoleRequest()
58+
.withRoleSessionName("DynamoDBAssumed")
59+
.withRoleArn(arn)
60+
)
61+
val stsCredentials = assumeRoleResult.getCredentials
62+
val assumeCreds = new BasicSessionCredentials(
63+
stsCredentials.getAccessKeyId,
64+
stsCredentials.getSecretAccessKey,
65+
stsCredentials.getSessionToken
66+
)
67+
AmazonDynamoDBClientBuilder.standard().withCredentials(new AWSStaticCredentialsProvider(assumeCreds)).build()
68+
}).getOrElse(AmazonDynamoDBClientBuilder.standard().withRegion(chosenRegion).build())
69+
)
4870
}
4971

5072
val keySchema: KeySchema

src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
4141
private val consistentRead = parameters.getOrElse("stronglyConsistentReads", "false").toBoolean
4242
private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean
4343
private val region = parameters.get("region")
44+
private val assumedArn = parameters.get("assumedArn")
4445

4546
override val (keySchema, readLimit, writeLimit, itemLimit, totalSizeInBytes) = {
46-
val table = getDynamoDB(region).getTable(tableName)
47+
val table = getDynamoDB(region, assumedArn).getTable(tableName)
4748
val desc = table.describe()
4849

4950
// Key schema.
@@ -94,7 +95,7 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
9495
scanSpec.withExpressionSpec(xspec.buildForScan())
9596
}
9697

97-
getDynamoDB(region).getTable(tableName).scan(scanSpec)
98+
getDynamoDB(region, assumedArn).getTable(tableName).scan(scanSpec)
9899
}
99100

100101
override def putItems(schema: StructType, batchSize: Int)(items: Iterator[Row]): Unit = {
@@ -109,7 +110,7 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
109110
})
110111

111112
val rateLimiter = RateLimiter.create(writeLimit max 1)
112-
val client = getDynamoDB(region)
113+
val client = getDynamoDB(region, assumedArn)
113114

114115
// For each batch.
115116
items.grouped(batchSize).foreach(itemBatch => {
@@ -154,7 +155,7 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
154155
})
155156

156157
val rateLimiter = RateLimiter.create(writeLimit max 1)
157-
val client = getDynamoDB(region)
158+
val client = getDynamoDB(region, assumedArn)
158159

159160
// For each item.
160161
items.foreach(row => {

src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
3434
private val consistentRead = parameters.getOrElse("stronglyConsistentReads", "false").toBoolean
3535
private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean
3636
private val region = parameters.get("region")
37+
private val assumedArn = parameters.get("assumedArn")
3738

3839
override val (keySchema, readLimit, itemLimit, totalSizeInBytes) = {
39-
val table = getDynamoDB(region).getTable(tableName)
40+
val table = getDynamoDB(region, assumedArn).getTable(tableName)
4041
val indexDesc = table.describe().getGlobalSecondaryIndexes.asScala.find(_.getIndexName == indexName).get
4142

4243
// Key schema.
@@ -81,7 +82,7 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
8182
scanSpec.withExpressionSpec(xspec.buildForScan())
8283
}
8384

84-
getDynamoDB(region).getTable(tableName).getIndex(indexName).scan(scanSpec)
85+
getDynamoDB(region, assumedArn).getTable(tableName).getIndex(indexName).scan(scanSpec)
8586
}
8687

8788
}

0 commit comments

Comments
 (0)