This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit f3d2974

Merge branch 'htorrence-assume_role_option'

2 parents: ee645ec + 591d604

5 files changed: +51 −15 lines changed

README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -49,6 +49,7 @@ Spark is used in the library as a "provided" dependency, which means Spark has t
 ## Parameters
 The following parameters can be set as options on the Spark reader and writer object before loading/saving.
 - `region` sets the region where the dynamodb table. Default is environment specific.
+- `assumedArn` sets an IAM role to assume. This allows for access to a DynamoDB in a different account than the spark cluster. Defaults to the standard role configuration.
 
 
 The following parameters can be set as options on the Spark reader object before loading.
```
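For illustration only (not part of this commit), a read-side sketch of the new option. The table name and role ARN are placeholders, and the `dynamodb` method assumes the library's reader implicits documented elsewhere in the README:

```scala
import com.audienceproject.spark.dynamodb.implicits._

// Scan a table that lives in another AWS account by assuming a role there.
// The ARN and table name below are placeholders.
val dynamoDf = spark.read
    .option("region", "us-east-1")
    .option("assumedArn", "arn:aws:iam::123456789012:role/cross-account-dynamo-read")
    .dynamodb("SomeTableName")
```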

build.sbt

Lines changed: 2 additions & 1 deletion

```diff
@@ -2,7 +2,7 @@ organization := "com.audienceproject"
 
 name := "spark-dynamodb"
 
-version := "0.4.2"
+version := "0.5.0"
 
 description := "Plug-and-play implementation of an Apache Spark custom data source for AWS DynamoDB."
 
@@ -14,6 +14,7 @@ resolvers += "DynamoDBLocal" at "https://s3-us-west-2.amazonaws.com/dynamodb-loc
 
 libraryDependencies += "com.amazonaws" % "aws-java-sdk-dynamodb" % "1.11.466"
 libraryDependencies += "com.amazonaws" % "DynamoDBLocal" % "[1.11,2.0)" % "test" exclude("com.google.guava", "guava")
+libraryDependencies += "com.amazonaws" % "aws-java-sdk" % "1.11.49"
 
 libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.0" % "provided"
 libraryDependencies += "com.google.guava" % "guava" % "14.0.1" % "provided"
```

src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala

Lines changed: 40 additions & 8 deletions

```diff
@@ -20,31 +20,63 @@
  */
 package com.audienceproject.spark.dynamodb.connector
 
-import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
+import com.amazonaws.auth.{AWSStaticCredentialsProvider, BasicSessionCredentials, DefaultAWSCredentialsProviderChain}
 import com.amazonaws.auth.profile.ProfileCredentialsProvider
 import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
 import com.amazonaws.services.dynamodbv2.document.{DynamoDB, ItemCollection, ScanOutcome}
 import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBClientBuilder}
+import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder
+import com.amazonaws.services.securitytoken.model.AssumeRoleRequest
 import org.apache.spark.sql.sources.Filter
 
 private[dynamodb] trait DynamoConnector {
 
-    def getDynamoDB(region: Option[String] = None): DynamoDB = {
-        val client: AmazonDynamoDB = getDynamoDBClient(region)
+    def getDynamoDB(region: Option[String] = None, assumedArn: Option[String] = None): DynamoDB = {
+        val client: AmazonDynamoDB = getDynamoDBClient(region, assumedArn)
         new DynamoDB(client)
     }
 
-    def getDynamoDBClient(region: Option[String] = None): AmazonDynamoDB = {
+    /**
+      * Get credentials from a passed in arn or from profile or return the default credential provider
+      */
+    private def getCredentials(chosenRegion: String, assumedArn: Option[String]) = {
+        assumedArn.map(arn => {
+            val stsClient = AWSSecurityTokenServiceClientBuilder
+                .standard()
+                .withCredentials(new DefaultAWSCredentialsProviderChain)
+                .withRegion(chosenRegion)
+                .build()
+            val assumeRoleResult = stsClient.assumeRole(
+                new AssumeRoleRequest()
+                    .withRoleSessionName("DynamoDBAssumed")
+                    .withRoleArn(arn)
+            )
+            val stsCredentials = assumeRoleResult.getCredentials
+            val assumeCreds = new BasicSessionCredentials(
+                stsCredentials.getAccessKeyId,
+                stsCredentials.getSecretAccessKey,
+                stsCredentials.getSessionToken
+            )
+            new AWSStaticCredentialsProvider(assumeCreds)
+        }).orElse(Option(System.getProperty("aws.profile")).map(new ProfileCredentialsProvider(_)))
+            .getOrElse(new DefaultAWSCredentialsProviderChain)
+    }
+
+    def getDynamoDBClient(region: Option[String] = None, assumedArn: Option[String] = None): AmazonDynamoDB = {
         val chosenRegion = region.getOrElse(sys.env.getOrElse("aws.dynamodb.region", "us-east-1"))
+        val credentials = getCredentials(chosenRegion, assumedArn)
+
         Option(System.getProperty("aws.dynamodb.endpoint")).map(endpoint => {
-            val credentials = Option(System.getProperty("aws.profile"))
-                .map(new ProfileCredentialsProvider(_))
-                .getOrElse(new DefaultAWSCredentialsProviderChain)
             AmazonDynamoDBClientBuilder.standard()
                 .withCredentials(credentials)
                 .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
                 .build()
-        }).getOrElse(AmazonDynamoDBClientBuilder.standard().withRegion(chosenRegion).build())
+        }).getOrElse(
+            AmazonDynamoDBClientBuilder.standard()
+                .withCredentials(credentials)
+                .withRegion(chosenRegion)
+                .build()
+        )
     }
 
     val keySchema: KeySchema
```
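A caveat worth noting about the credentials logic above: `assumeRole` returns temporary session credentials (one hour by default), and wrapping them in `AWSStaticCredentialsProvider` means they are never refreshed, so a job outlasting the session could see expired-token errors. The SDK ships an auto-refreshing provider; a minimal sketch of that alternative (an assumption, not what this commit implements):

```scala
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider

// Sketch: this provider re-assumes the role on its own whenever the
// session credentials approach expiry, unlike a static provider.
def refreshingCredentials(arn: String): STSAssumeRoleSessionCredentialsProvider =
    new STSAssumeRoleSessionCredentialsProvider.Builder(arn, "DynamoDBAssumed").build()
```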

src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala

Lines changed: 5 additions & 4 deletions

```diff
@@ -41,9 +41,10 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
     private val consistentRead = parameters.getOrElse("stronglyConsistentReads", "false").toBoolean
     private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean
     private val region = parameters.get("region")
+    private val assumedArn = parameters.get("assumedArn")
 
     override val (keySchema, readLimit, writeLimit, itemLimit, totalSizeInBytes) = {
-        val table = getDynamoDB(region).getTable(tableName)
+        val table = getDynamoDB(region, assumedArn).getTable(tableName)
         val desc = table.describe()
 
         // Key schema.
@@ -94,7 +95,7 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
             scanSpec.withExpressionSpec(xspec.buildForScan())
         }
 
-        getDynamoDB(region).getTable(tableName).scan(scanSpec)
+        getDynamoDB(region, assumedArn).getTable(tableName).scan(scanSpec)
     }
 
     override def putItems(schema: StructType, batchSize: Int)(items: Iterator[Row]): Unit = {
@@ -109,7 +110,7 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
         })
 
         val rateLimiter = RateLimiter.create(writeLimit max 1)
-        val client = getDynamoDB(region)
+        val client = getDynamoDB(region, assumedArn)
 
         // For each batch.
         items.grouped(batchSize).foreach(itemBatch => {
@@ -154,7 +155,7 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
         })
 
         val rateLimiter = RateLimiter.create(writeLimit max 1)
-        val client = getDynamoDB(region)
+        val client = getDynamoDB(region, assumedArn)
 
         // For each item.
         items.foreach(row => {
```
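Since `TableConnector` backs `putItems` as well, the same option covers the write path. A write-side sketch (placeholders throughout, assuming the writer implicits mirror the reader's `dynamodb` method):

```scala
import com.audienceproject.spark.dynamodb.implicits._

// Write to a table in another account through the assumed role.
// Table name and ARN are placeholders.
someDf.write
    .option("region", "us-east-1")
    .option("assumedArn", "arn:aws:iam::123456789012:role/cross-account-dynamo-write")
    .dynamodb("TargetTable")
```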

src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala

Lines changed: 3 additions & 2 deletions

```diff
@@ -34,9 +34,10 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
     private val consistentRead = parameters.getOrElse("stronglyConsistentReads", "false").toBoolean
     private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean
     private val region = parameters.get("region")
+    private val assumedArn = parameters.get("assumedArn")
 
     override val (keySchema, readLimit, itemLimit, totalSizeInBytes) = {
-        val table = getDynamoDB(region).getTable(tableName)
+        val table = getDynamoDB(region, assumedArn).getTable(tableName)
         val indexDesc = table.describe().getGlobalSecondaryIndexes.asScala.find(_.getIndexName == indexName).get
 
         // Key schema.
@@ -81,7 +82,7 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
             scanSpec.withExpressionSpec(xspec.buildForScan())
         }
 
-        getDynamoDB(region).getTable(tableName).getIndex(indexName).scan(scanSpec)
+        getDynamoDB(region, assumedArn).getTable(tableName).getIndex(indexName).scan(scanSpec)
     }
 
 }
```
