21
21
package com .audienceproject .spark .dynamodb .connector
22
22
23
23
import com .amazonaws .auth .profile .ProfileCredentialsProvider
24
- import com .amazonaws .auth .{AWSStaticCredentialsProvider , BasicSessionCredentials , DefaultAWSCredentialsProviderChain }
24
+ import com .amazonaws .auth .{AWSCredentialsProvider , AWSStaticCredentialsProvider , BasicSessionCredentials , DefaultAWSCredentialsProviderChain }
25
25
import com .amazonaws .client .builder .AwsClientBuilder .EndpointConfiguration
26
26
import com .amazonaws .services .dynamodbv2 .document .{DynamoDB , ItemCollection , ScanOutcome }
27
27
import com .amazonaws .services .dynamodbv2 .{AmazonDynamoDB , AmazonDynamoDBAsync , AmazonDynamoDBAsyncClientBuilder , AmazonDynamoDBClientBuilder }
@@ -33,14 +33,16 @@ private[dynamodb] trait DynamoConnector {
33
33
34
34
@ transient private lazy val properties = sys.props
35
35
36
- def getDynamoDB (region : Option [String ] = None , roleArn : Option [String ] = None ): DynamoDB = {
37
- val client : AmazonDynamoDB = getDynamoDBClient(region, roleArn)
36
+ def getDynamoDB (region : Option [String ] = None , roleArn : Option [String ] = None , providerClassName : Option [ String ] = None ): DynamoDB = {
37
+ val client : AmazonDynamoDB = getDynamoDBClient(region, roleArn, providerClassName )
38
38
new DynamoDB (client)
39
39
}
40
40
41
- private def getDynamoDBClient (region : Option [String ] = None , roleArn : Option [String ] = None ): AmazonDynamoDB = {
41
+ private def getDynamoDBClient (region : Option [String ] = None ,
42
+ roleArn : Option [String ] = None ,
43
+ providerClassName : Option [String ]): AmazonDynamoDB = {
42
44
val chosenRegion = region.getOrElse(properties.getOrElse(" aws.dynamodb.region" , " us-east-1" ))
43
- val credentials = getCredentials(chosenRegion, roleArn)
45
+ val credentials = getCredentials(chosenRegion, roleArn, providerClassName )
44
46
45
47
properties.get(" aws.dynamodb.endpoint" ).map(endpoint => {
46
48
AmazonDynamoDBClientBuilder .standard()
@@ -55,9 +57,11 @@ private[dynamodb] trait DynamoConnector {
55
57
)
56
58
}
57
59
58
- def getDynamoDBAsyncClient (region : Option [String ] = None , roleArn : Option [String ] = None ): AmazonDynamoDBAsync = {
60
+ def getDynamoDBAsyncClient (region : Option [String ] = None ,
61
+ roleArn : Option [String ] = None ,
62
+ providerClassName : Option [String ] = None ): AmazonDynamoDBAsync = {
59
63
val chosenRegion = region.getOrElse(properties.getOrElse(" aws.dynamodb.region" , " us-east-1" ))
60
- val credentials = getCredentials(chosenRegion, roleArn)
64
+ val credentials = getCredentials(chosenRegion, roleArn, providerClassName )
61
65
62
66
properties.get(" aws.dynamodb.endpoint" ).map(endpoint => {
63
67
AmazonDynamoDBAsyncClientBuilder .standard()
@@ -73,10 +77,15 @@ private[dynamodb] trait DynamoConnector {
73
77
}
74
78
75
79
/**
76
- * Get credentials from a passed in arn or from profile or return the default credential provider
77
- **/
78
- private def getCredentials (chosenRegion : String , roleArn : Option [String ]) = {
79
- roleArn.map(arn => {
80
+ * Get credentials from an instantiated object of the class name given
81
+ * or a passed in arn
82
+ * or from profile
83
+ * or return the default credential provider
84
+ **/
85
+ private def getCredentials (chosenRegion : String , roleArn : Option [String ], providerClassName : Option [String ]) = {
86
+ providerClassName.map(providerClass => {
87
+ Class .forName(providerClass).newInstance.asInstanceOf [AWSCredentialsProvider ]
88
+ }).orElse(roleArn.map(arn => {
80
89
val stsClient = properties.get(" aws.sts.endpoint" ).map(endpoint => {
81
90
AWSSecurityTokenServiceClientBuilder
82
91
.standard()
@@ -103,7 +112,7 @@ private[dynamodb] trait DynamoConnector {
103
112
stsCredentials.getSessionToken
104
113
)
105
114
new AWSStaticCredentialsProvider (assumeCreds)
106
- }).orElse(properties.get(" aws.profile" ).map(new ProfileCredentialsProvider (_)))
115
+ })) .orElse(properties.get(" aws.profile" ).map(new ProfileCredentialsProvider (_)))
107
116
.getOrElse(new DefaultAWSCredentialsProviderChain )
108
117
}
109
118
0 commit comments