This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit a390bca

Allows giving providerClassName, used to create an AWSCredentialsProvider object
1 parent ba7e8c8 commit a390bca
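
As a quick usage sketch (not part of this commit): the option key "providerclassname" matches the parameter lookup added in TableConnector below, while the data source short name, the table option name, and the provider class (sketched after the DynamoConnector diff) are illustrative assumptions.

import org.apache.spark.sql.SparkSession

// Sketch only: assumes the source is registered under the short name "dynamodb"
// and that reader options reach TableConnector's parameters map unchanged.
val spark = SparkSession.builder().appName("dynamodb-read").getOrCreate()

val df = spark.read
    .format("dynamodb")                                                     // assumed short name
    .option("region", "eu-west-1")
    .option("providerclassname", "com.example.auth.MyCredentialsProvider")  // option added by this commit
    .option("tableName", "SomeTable")                                       // assumed table option name
    .load()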

File tree

4 files changed: +29 -17 lines changed

src/main/scala/com/audienceproject/spark/dynamodb/connector/DynamoConnector.scala

Lines changed: 21 additions & 12 deletions
@@ -21,7 +21,7 @@
 package com.audienceproject.spark.dynamodb.connector
 
 import com.amazonaws.auth.profile.ProfileCredentialsProvider
-import com.amazonaws.auth.{AWSStaticCredentialsProvider, BasicSessionCredentials, DefaultAWSCredentialsProviderChain}
+import com.amazonaws.auth.{AWSCredentialsProvider, AWSStaticCredentialsProvider, BasicSessionCredentials, DefaultAWSCredentialsProviderChain}
 import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
 import com.amazonaws.services.dynamodbv2.document.{DynamoDB, ItemCollection, ScanOutcome}
 import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBAsync, AmazonDynamoDBAsyncClientBuilder, AmazonDynamoDBClientBuilder}
@@ -33,14 +33,16 @@ private[dynamodb] trait DynamoConnector {
 
     @transient private lazy val properties = sys.props
 
-    def getDynamoDB(region: Option[String] = None, roleArn: Option[String] = None): DynamoDB = {
-        val client: AmazonDynamoDB = getDynamoDBClient(region, roleArn)
+    def getDynamoDB(region: Option[String] = None, roleArn: Option[String] = None, providerClassName: Option[String] = None): DynamoDB = {
+        val client: AmazonDynamoDB = getDynamoDBClient(region, roleArn, providerClassName)
         new DynamoDB(client)
     }
 
-    private def getDynamoDBClient(region: Option[String] = None, roleArn: Option[String] = None): AmazonDynamoDB = {
+    private def getDynamoDBClient(region: Option[String] = None,
+                                  roleArn: Option[String] = None,
+                                  providerClassName: Option[String]): AmazonDynamoDB = {
         val chosenRegion = region.getOrElse(properties.getOrElse("aws.dynamodb.region", "us-east-1"))
-        val credentials = getCredentials(chosenRegion, roleArn)
+        val credentials = getCredentials(chosenRegion, roleArn, providerClassName)
 
         properties.get("aws.dynamodb.endpoint").map(endpoint => {
             AmazonDynamoDBClientBuilder.standard()
@@ -55,9 +57,11 @@ private[dynamodb] trait DynamoConnector {
         )
     }
 
-    def getDynamoDBAsyncClient(region: Option[String] = None, roleArn: Option[String] = None): AmazonDynamoDBAsync = {
+    def getDynamoDBAsyncClient(region: Option[String] = None,
+                               roleArn: Option[String] = None,
+                               providerClassName: Option[String] = None): AmazonDynamoDBAsync = {
         val chosenRegion = region.getOrElse(properties.getOrElse("aws.dynamodb.region", "us-east-1"))
-        val credentials = getCredentials(chosenRegion, roleArn)
+        val credentials = getCredentials(chosenRegion, roleArn, providerClassName)
 
         properties.get("aws.dynamodb.endpoint").map(endpoint => {
             AmazonDynamoDBAsyncClientBuilder.standard()
@@ -73,10 +77,15 @@ private[dynamodb] trait DynamoConnector {
     }
 
     /**
-     * Get credentials from a passed in arn or from profile or return the default credential provider
-     **/
-    private def getCredentials(chosenRegion: String, roleArn: Option[String]) = {
-        roleArn.map(arn => {
+     * Get credentials from an instantiated object of the class name given
+     * or a passed in arn
+     * or from profile
+     * or return the default credential provider
+     **/
+    private def getCredentials(chosenRegion: String, roleArn: Option[String], providerClassName: Option[String]) = {
+        providerClassName.map(providerClass => {
+            Class.forName(providerClass).newInstance.asInstanceOf[AWSCredentialsProvider]
+        }).orElse(roleArn.map(arn => {
            val stsClient = properties.get("aws.sts.endpoint").map(endpoint => {
                AWSSecurityTokenServiceClientBuilder
                    .standard()
@@ -103,7 +112,7 @@ private[dynamodb] trait DynamoConnector {
                stsCredentials.getSessionToken
            )
            new AWSStaticCredentialsProvider(assumeCreds)
-        }).orElse(properties.get("aws.profile").map(new ProfileCredentialsProvider(_)))
+        })).orElse(properties.get("aws.profile").map(new ProfileCredentialsProvider(_)))
            .getOrElse(new DefaultAWSCredentialsProviderChain)
    }

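Because getCredentials above instantiates the named class via Class.forName(providerClass).newInstance and casts it to AWSCredentialsProvider, the class must implement that interface and expose a public no-argument constructor. A minimal sketch of such a provider (the class name and the property keys it reads are illustrative assumptions, not part of this commit):

package com.example.auth

import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, BasicAWSCredentials}

// Hypothetical provider: resolves static keys from JVM system properties.
// A public no-arg constructor is required because the connector creates it reflectively.
class MyCredentialsProvider extends AWSCredentialsProvider {

    override def getCredentials(): AWSCredentials = new BasicAWSCredentials(
        sys.props.getOrElse("my.app.accessKey", sys.error("my.app.accessKey is not set")),
        sys.props.getOrElse("my.app.secretKey", sys.error("my.app.secretKey is not set"))
    )

    override def refresh(): Unit = () // nothing is cached, so nothing to refresh
}

Note that in the orElse chain above, a given providerClassName takes precedence over roleArn, the aws.profile property, and the default provider chain.
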
src/main/scala/com/audienceproject/spark/dynamodb/connector/TableConnector.scala

Lines changed: 3 additions & 2 deletions
@@ -39,11 +39,12 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
     private val filterPushdown = parameters.getOrElse("filterpushdown", "true").toBoolean
     private val region = parameters.get("region")
     private val roleArn = parameters.get("rolearn")
+    private val providerClassName = parameters.get("providerclassname")
 
     override val filterPushdownEnabled: Boolean = filterPushdown
 
     override val (keySchema, readLimit, writeLimit, itemLimit, totalSegments) = {
-        val table = getDynamoDB(region, roleArn).getTable(tableName)
+        val table = getDynamoDB(region, roleArn, providerClassName).getTable(tableName)
         val desc = table.describe()
 
         // Key schema.
@@ -106,7 +107,7 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
             scanSpec.withExpressionSpec(xspec.buildForScan())
         }
 
-        getDynamoDB(region, roleArn).getTable(tableName).scan(scanSpec)
+        getDynamoDB(region, roleArn, providerClassName).getTable(tableName).scan(scanSpec)
     }
 
     override def putItems(columnSchema: ColumnSchema, items: Seq[InternalRow])

src/main/scala/com/audienceproject/spark/dynamodb/connector/TableIndexConnector.scala

Lines changed: 3 additions & 2 deletions
@@ -35,11 +35,12 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
     private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean
     private val region = parameters.get("region")
     private val roleArn = parameters.get("roleArn")
+    private val providerClassName = parameters.get("providerclassname")
 
     override val filterPushdownEnabled: Boolean = filterPushdown
 
     override val (keySchema, readLimit, itemLimit, totalSegments) = {
-        val table = getDynamoDB(region, roleArn).getTable(tableName)
+        val table = getDynamoDB(region, roleArn, providerClassName).getTable(tableName)
         val indexDesc = table.describe().getGlobalSecondaryIndexes.asScala.find(_.getIndexName == indexName).get
 
         // Key schema.
@@ -96,7 +97,7 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
             scanSpec.withExpressionSpec(xspec.buildForScan())
         }
 
-        getDynamoDB(region, roleArn).getTable(tableName).getIndex(indexName).scan(scanSpec)
+        getDynamoDB(region, roleArn, providerClassName).getTable(tableName).getIndex(indexName).scan(scanSpec)
     }
 
 }

src/main/scala/com/audienceproject/spark/dynamodb/datasource/DynamoWriterFactory.scala

Lines changed: 2 additions & 1 deletion
@@ -36,10 +36,11 @@ class DynamoWriterFactory(connector: TableConnector,
 
     private val region = parameters.get("region")
     private val roleArn = parameters.get("rolearn")
+    private val providerClassName = parameters.get("providerclassname")
 
     override def createDataWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = {
         val columnSchema = new ColumnSchema(connector.keySchema, schema)
-        val client = connector.getDynamoDB(region, roleArn)
+        val client = connector.getDynamoDB(region, roleArn, providerClassName)
         if (update) {
             assert(!delete, "Please provide exactly one of 'update' or 'delete' options.")
             new DynamoUpdateWriter(columnSchema, connector, client)

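The same option flows through the write path via DynamoWriterFactory above. A write sketch under the same assumptions as the read example (format short name and table option name are not confirmed by this diff):

// Sketch only: df is the DataFrame from the read example near the top.
df.write
    .format("dynamodb")                                                     // assumed short name
    .option("region", "eu-west-1")
    .option("providerclassname", "com.example.auth.MyCredentialsProvider")  // forwarded to DynamoWriterFactory's parameters
    .option("tableName", "SomeTable")                                       // assumed table option name
    .save()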