This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit ff7780f

Merge PR htorrence:assume_role_option
1 parent f3d2974 commit ff7780f

File tree: 5 files changed, +36 −36 lines changed


README.md

Lines changed: 1 addition & 1 deletion

@@ -49,7 +49,7 @@ Spark is used in the library as a "provided" dependency, which means Spark has t
 ## Parameters
 The following parameters can be set as options on the Spark reader and writer object before loading/saving.
 - `region` sets the region where the dynamodb table. Default is environment specific.
-- `assumedArn` sets an IAM role to assume. This allows for access to a DynamoDB in a different account than the spark cluster. Defaults to the standard role configuration.
+- `roleArn` sets an IAM role to assume. This allows for access to a DynamoDB in a different account than the Spark cluster. Defaults to the standard role configuration.
 
 
 The following parameters can be set as options on the Spark reader object before loading.
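To illustrate the renamed option, here is a minimal read-side sketch. The table name, region, and role ARN are hypothetical, and the implicits-based `dynamodb` reader is assumed from the library's usage docs, which are not part of this diff:

```scala
import org.apache.spark.sql.SparkSession
import com.audienceproject.spark.dynamodb.implicits._ // assumed import path

val spark = SparkSession.builder().appName("RoleArnExample").getOrCreate()

// `roleArn` tells the connector to assume the given IAM role via STS,
// enabling access to a table in a different account than the Spark cluster.
val df = spark.read
    .option("region", "us-east-1")                                       // hypothetical region
    .option("roleArn", "arn:aws:iam::123456789012:role/dynamodb-reader") // hypothetical ARN
    .dynamodb("SomeTable")                                               // hypothetical table

df.show()
```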

build.sbt

Lines changed: 3 additions & 3 deletions

@@ -2,7 +2,7 @@ organization := "com.audienceproject"
 
 name := "spark-dynamodb"
 
-version := "0.5.0"
+version := "0.4.3"
 
 description := "Plug-and-play implementation of an Apache Spark custom data source for AWS DynamoDB."
 
@@ -12,9 +12,9 @@ crossScalaVersions := Seq("2.11.12", "2.12.7")
 
 resolvers += "DynamoDBLocal" at "https://s3-us-west-2.amazonaws.com/dynamodb-local/release"
 
-libraryDependencies += "com.amazonaws" % "aws-java-sdk-dynamodb" % "1.11.466"
+libraryDependencies += "com.amazonaws" % "aws-java-sdk-sts" % "1.11.571"
+libraryDependencies += "com.amazonaws" % "aws-java-sdk-dynamodb" % "1.11.571"
 libraryDependencies += "com.amazonaws" % "DynamoDBLocal" % "[1.11,2.0)" % "test" exclude("com.google.guava", "guava")
-libraryDependencies += "com.amazonaws" % "aws-java-sdk" % "1.11.49"
 
 libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.0" % "provided"
 libraryDependencies += "com.google.guava" % "guava" % "14.0.1" % "provided"

src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala

Lines changed: 24 additions & 24 deletions

@@ -20,8 +20,8 @@
  */
 package com.audienceproject.spark.dynamodb.connector
 
-import com.amazonaws.auth.{AWSStaticCredentialsProvider, BasicSessionCredentials, DefaultAWSCredentialsProviderChain}
 import com.amazonaws.auth.profile.ProfileCredentialsProvider
+import com.amazonaws.auth.{AWSStaticCredentialsProvider, BasicSessionCredentials, DefaultAWSCredentialsProviderChain}
 import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
 import com.amazonaws.services.dynamodbv2.document.{DynamoDB, ItemCollection, ScanOutcome}
 import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBClientBuilder}
@@ -31,16 +31,33 @@ import org.apache.spark.sql.sources.Filter
 
 private[dynamodb] trait DynamoConnector {
 
-    def getDynamoDB(region: Option[String] = None, assumedArn: Option[String] = None): DynamoDB = {
-        val client: AmazonDynamoDB = getDynamoDBClient(region, assumedArn)
+    def getDynamoDB(region: Option[String] = None, roleArn: Option[String] = None): DynamoDB = {
+        val client: AmazonDynamoDB = getDynamoDBClient(region, roleArn)
         new DynamoDB(client)
     }
 
+    private def getDynamoDBClient(region: Option[String] = None, roleArn: Option[String] = None): AmazonDynamoDB = {
+        val chosenRegion = region.getOrElse(sys.env.getOrElse("aws.dynamodb.region", "us-east-1"))
+        val credentials = getCredentials(chosenRegion, roleArn)
+
+        Option(System.getProperty("aws.dynamodb.endpoint")).map(endpoint => {
+            AmazonDynamoDBClientBuilder.standard()
+                .withCredentials(credentials)
+                .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
+                .build()
+        }).getOrElse(
+            AmazonDynamoDBClientBuilder.standard()
+                .withCredentials(credentials)
+                .withRegion(chosenRegion)
+                .build()
+        )
+    }
+
     /**
      * Get credentials from a passed in arn or from profile or return the default credential provider
-     * */
-    private def getCredentials(chosenRegion: String, assumedArn: Option[String]) = {
-        assumedArn.map(arn => {
+     **/
+    private def getCredentials(chosenRegion: String, roleArn: Option[String]) = {
+        roleArn.map(arn => {
             val stsClient = AWSSecurityTokenServiceClientBuilder
                 .standard()
                 .withCredentials(new DefaultAWSCredentialsProviderChain)
@@ -59,24 +76,7 @@ private[dynamodb] trait DynamoConnector {
             )
             new AWSStaticCredentialsProvider(assumeCreds)
         }).orElse(Option(System.getProperty("aws.profile")).map(new ProfileCredentialsProvider(_)))
-            .getOrElse(new DefaultAWSCredentialsProviderChain)
-    }
-
-    def getDynamoDBClient(region: Option[String] = None, assumedArn: Option[String] = None): AmazonDynamoDB = {
-        val chosenRegion = region.getOrElse(sys.env.getOrElse("aws.dynamodb.region", "us-east-1"))
-        val credentials = getCredentials(chosenRegion, assumedArn)
-
-        Option(System.getProperty("aws.dynamodb.endpoint")).map(endpoint => {
-            AmazonDynamoDBClientBuilder.standard()
-                .withCredentials(credentials)
-                .withEndpointConfiguration(new EndpointConfiguration(endpoint, chosenRegion))
-                .build()
-        }).getOrElse(
-            AmazonDynamoDBClientBuilder.standard()
-                .withCredentials(credentials)
-                .withRegion(chosenRegion)
-                .build()
-        )
+            .getOrElse(new DefaultAWSCredentialsProviderChain)
     }
 
     val keySchema: KeySchema
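The hunks above move `getDynamoDBClient` and rename `assumedArn` to `roleArn`, but elide the middle of `getCredentials`, where the STS client exchanges the role ARN for the temporary credentials bound as `assumeCreds`. A standalone sketch of that standard AWS SDK v1 assume-role flow follows; it is not the commit's exact code, and the session name is hypothetical:

```scala
import com.amazonaws.auth.{AWSStaticCredentialsProvider, BasicSessionCredentials, DefaultAWSCredentialsProviderChain}
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder
import com.amazonaws.services.securitytoken.model.AssumeRoleRequest

// Sketch of the usual STS assume-role exchange in AWS SDK v1.
def assumeRoleProvider(roleArn: String, region: String): AWSStaticCredentialsProvider = {
    val stsClient = AWSSecurityTokenServiceClientBuilder
        .standard()
        .withCredentials(new DefaultAWSCredentialsProviderChain)
        .withRegion(region)
        .build()

    // Ask STS for temporary credentials under the given role.
    val result = stsClient.assumeRole(
        new AssumeRoleRequest()
            .withRoleArn(roleArn)
            .withRoleSessionName("spark-dynamodb") // hypothetical session name
    )
    val creds = result.getCredentials

    // Wrap the temporary keys as static session credentials for the DynamoDB client.
    new AWSStaticCredentialsProvider(
        new BasicSessionCredentials(
            creds.getAccessKeyId,
            creds.getSecretAccessKey,
            creds.getSessionToken
        )
    )
}
```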

src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala

Lines changed: 5 additions & 5 deletions

@@ -41,10 +41,10 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
     private val consistentRead = parameters.getOrElse("stronglyConsistentReads", "false").toBoolean
     private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean
     private val region = parameters.get("region")
-    private val assumedArn = parameters.get("assumedArn")
+    private val roleArn = parameters.get("roleArn")
 
     override val (keySchema, readLimit, writeLimit, itemLimit, totalSizeInBytes) = {
-        val table = getDynamoDB(region, assumedArn).getTable(tableName)
+        val table = getDynamoDB(region, roleArn).getTable(tableName)
         val desc = table.describe()
 
         // Key schema.
@@ -95,7 +95,7 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
             scanSpec.withExpressionSpec(xspec.buildForScan())
         }
 
-        getDynamoDB(region, assumedArn).getTable(tableName).scan(scanSpec)
+        getDynamoDB(region, roleArn).getTable(tableName).scan(scanSpec)
     }
 
     override def putItems(schema: StructType, batchSize: Int)(items: Iterator[Row]): Unit = {
@@ -110,7 +110,7 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
         })
 
         val rateLimiter = RateLimiter.create(writeLimit max 1)
-        val client = getDynamoDB(region, assumedArn)
+        val client = getDynamoDB(region, roleArn)
 
         // For each batch.
         items.grouped(batchSize).foreach(itemBatch => {
@@ -155,7 +155,7 @@ private[dynamodb] class TableConnector(tableName: String, totalSegments: Int, pa
         })
 
         val rateLimiter = RateLimiter.create(writeLimit max 1)
-        val client = getDynamoDB(region, assumedArn)
+        val client = getDynamoDB(region, roleArn)
 
         // For each item.
         items.foreach(row => {
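TableConnector reads `roleArn` from the writer options too, so the same cross-account setup applies on the write path. A minimal sketch under the same assumptions as the read example above:

```scala
// Hypothetical DataFrame `df` written to a table in another account.
df.write
    .option("region", "us-east-1")                                       // hypothetical region
    .option("roleArn", "arn:aws:iam::123456789012:role/dynamodb-writer") // hypothetical ARN
    .dynamodb("SomeTable")                                               // hypothetical table
```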

src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala

Lines changed: 3 additions & 3 deletions

@@ -34,10 +34,10 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
     private val consistentRead = parameters.getOrElse("stronglyConsistentReads", "false").toBoolean
     private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean
     private val region = parameters.get("region")
-    private val assumedArn = parameters.get("assumedArn")
+    private val roleArn = parameters.get("roleArn")
 
     override val (keySchema, readLimit, itemLimit, totalSizeInBytes) = {
-        val table = getDynamoDB(region, assumedArn).getTable(tableName)
+        val table = getDynamoDB(region, roleArn).getTable(tableName)
         val indexDesc = table.describe().getGlobalSecondaryIndexes.asScala.find(_.getIndexName == indexName).get
 
         // Key schema.
@@ -82,7 +82,7 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
             scanSpec.withExpressionSpec(xspec.buildForScan())
         }
 
-        getDynamoDB(region, assumedArn).getTable(tableName).getIndex(indexName).scan(scanSpec)
+        getDynamoDB(region, roleArn).getTable(tableName).getIndex(indexName).scan(scanSpec)
     }
 
 }
