Skip to content

Commit 8a3742f

Browse files
authored
fix: expose credential provider to customer properly (#362)
1 parent 6efe36b commit 8a3742f

File tree

9 files changed

+90
-11
lines changed

9 files changed

+90
-11
lines changed

AWSClientRuntime/Sources/AWSClientConfiguration.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import ClientRuntime
77

88
public protocol AWSRuntimeConfiguration {
9-
var credentialsProvider: AWSCredentialsProvider { get set }
9+
var credentialsProvider: CredentialsProvider { get set }
1010
var region: String? { get set }
1111
var signingRegion: String? {get set}
1212
var endpointResolver: EndpointResolver {get set}

AWSClientRuntime/Sources/Auth/AWSCredentialsProvider.swift

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import AwsCommonRuntimeKit
66
import ClientRuntime
77

8-
public class AWSCredentialsProvider {
8+
public class AWSCredentialsProvider: CredentialsProvider {
99
let crtCredentialsProvider: CRTAWSCredentialsProvider
1010

1111
init(awsCredentialsProvider: CRTAWSCredentialsProvider) {
@@ -47,6 +47,12 @@ public class AWSCredentialsProvider {
4747
return AWSCredentialsProvider(awsCredentialsProvider: credsProvider)
4848
}
4949

50+
public static func fromCustom(_ credentialsProvider: CredentialsProvider) throws -> AWSCredentialsProvider {
51+
let crtCredentialsProviderWrapper = CredentialsProviderCRTAdapter(credentialsProvider: credentialsProvider)
52+
let credsProvider = try CRTAWSCredentialsProvider(fromProvider: crtCredentialsProviderWrapper)
53+
return AWSCredentialsProvider(awsCredentialsProvider: credsProvider)
54+
}
55+
5056
public func getCredentials() throws -> AWSCredentials {
5157
let credentials = crtCredentialsProvider.getCredentials()
5258
let result = try credentials.get()
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
//
2+
// Copyright Amazon.com Inc. or its affiliates.
3+
// All Rights Reserved.
4+
//
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
8+
public protocol CredentialsProvider {
9+
/// Resolves `AWSCredentials` through custom means
10+
func getCredentials() throws -> AWSCredentials
11+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//
2+
// Copyright Amazon.com Inc. or its affiliates.
3+
// All Rights Reserved.
4+
//
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
8+
import AwsCommonRuntimeKit
9+
import ClientRuntime
10+
11+
struct CredentialsProviderCRTAdapter: CRTCredentialsProvider {
12+
var allocator: Allocator
13+
let credentialsProvider: CredentialsProvider
14+
let logger: SwiftLogger
15+
init(credentialsProvider: CredentialsProvider) {
16+
self.credentialsProvider = credentialsProvider
17+
self.logger = SwiftLogger(label: "CustomCredentialProvider")
18+
self.allocator = defaultAllocator
19+
}
20+
21+
func getCredentials(credentialCallbackData: CRTCredentialsProviderCallbackData) {
22+
do {
23+
let credentials = try credentialsProvider.getCredentials()
24+
let emptyError = AWSError(errorCode: 0)
25+
let crtCredentials = credentials.toCRTType()
26+
credentialCallbackData.onCredentialsResolved?(crtCredentials, CRTError.crtError(emptyError))
27+
} catch let err {
28+
logger.error("An error occurred with retrieving credentials from your custom credentials provider. Error: \(err)")
29+
30+
credentialCallbackData.onCredentialsResolved?(nil, CRTError.crtError(AWSError(errorCode: -1)))
31+
}
32+
33+
}
34+
}

AWSClientRuntime/Sources/HttpContextBuilder+Extension.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ extension HttpContextBuilder {
3939
}
4040

4141
@discardableResult
42-
public func withCredentialsProvider(value: AWSCredentialsProvider) -> HttpContextBuilder {
43-
self.attributes.set(key: AttributeKey<AWSCredentialsProvider>(name: "AWSCredentialsProvider"), value: value)
42+
public func withCredentialsProvider(value: CredentialsProvider) -> HttpContextBuilder {
43+
self.attributes.set(key: AttributeKey<CredentialsProvider>(name: "CredentialsProvider"), value: value)
4444
return self
4545
}
4646

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//
2+
// Copyright Amazon.com Inc. or its affiliates.
3+
// All Rights Reserved.
4+
//
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
8+
import ClientRuntime
9+
import SmithyTestUtil
10+
import XCTest
11+
@testable import AWSClientRuntime
12+
13+
class AWSCredentialProviderTests: XCTestCase {
14+
15+
func testYouCanUseCustomCredentialsProvider() throws {
16+
let awsCredsProvider = try AWSCredentialsProvider.fromCustom(MyCustomCredentialsProvider())
17+
let credentials = try awsCredsProvider.getCredentials()
18+
XCTAssertEqual(credentials.accessKey, "MYACCESSKEY")
19+
XCTAssertEqual(credentials.secret, "sekrit")
20+
}
21+
}
22+
23+
struct MyCustomCredentialsProvider: CredentialsProvider {
24+
func getCredentials() throws -> AWSCredentials {
25+
return AWSCredentials(accessKey: "MYACCESSKEY", secret: "sekrit", expirationTimeout: 30)
26+
}
27+
}

codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/AWSClientRuntimeTypes.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ object AWSClientRuntimeTypes {
3939
val RetrierMiddleware = runtimeSymbol("RetrierMiddleware")
4040
val EndpointResolverMiddleware = runtimeSymbol("EndpointResolverMiddleware")
4141
val EndpointResolver = runtimeSymbol("EndpointResolver")
42-
val CredentialsProvider = runtimeSymbol("AWSCredentialsProvider")
42+
val CredentialsProvider = runtimeSymbol("CredentialsProvider")
43+
val AWSCredentialsProvider = runtimeSymbol("AWSCredentialsProvider")
4344
val AWSClientConfiguration = runtimeSymbol("AWSClientConfiguration")
4445
val AWSEndpoint = runtimeSymbol("AWSEndpoint")
4546
val Partition = runtimeSymbol("Partition")

codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/AWSServiceConfig.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ class AWSServiceConfig(writer: SwiftWriter, serviceName: String) : ServiceConfig
3737
writer.write("self.signingRegion = signingRegion ?? defaultRegion")
3838
writer.write("self.endpointResolver = endpointResolver ?? DefaultEndpointResolver()")
3939
writer.openBlock("if let credProvider = credentialsProvider {", "} else {") {
40-
writer.write("self.credentialsProvider = credProvider")
40+
writer.write("self.credentialsProvider = try \$N.fromCustom(credProvider)", AWSClientRuntimeTypes.Core.AWSCredentialsProvider)
4141
}
42-
writer.indent().write("self.credentialsProvider = try \$N.fromChain()", AWSClientRuntimeTypes.Core.CredentialsProvider)
42+
writer.indent().write("self.credentialsProvider = try \$N.fromChain()", AWSClientRuntimeTypes.Core.AWSCredentialsProvider)
4343
writer.dedent().write("}")
4444
val runtimeTimeConfigFields = sdkRuntimeConfigProperties()
4545
runtimeTimeConfigFields.forEach {

codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/RestJsonProtocolGeneratorTests.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,14 @@ class RestJsonProtocolGeneratorTests {
116116
public var logger: ClientRuntime.LogAgent
117117
public var retrier: ClientRuntime.Retrier
118118
119-
public var credentialsProvider: AWSClientRuntime.AWSCredentialsProvider
119+
public var credentialsProvider: AWSClientRuntime.CredentialsProvider
120120
public var endpointResolver: AWSClientRuntime.EndpointResolver
121121
public var region: Swift.String?
122122
public var regionResolver: AWSClientRuntime.RegionResolver
123123
public var signingRegion: Swift.String?
124124
125125
public init(
126-
credentialsProvider: AWSClientRuntime.AWSCredentialsProvider? = nil,
126+
credentialsProvider: AWSClientRuntime.CredentialsProvider? = nil,
127127
endpointResolver: AWSClientRuntime.EndpointResolver? = nil,
128128
region: Swift.String? = nil,
129129
regionResolver: AWSClientRuntime.RegionResolver? = nil,
@@ -136,7 +136,7 @@ class RestJsonProtocolGeneratorTests {
136136
self.signingRegion = signingRegion ?? defaultRegion
137137
self.endpointResolver = endpointResolver ?? DefaultEndpointResolver()
138138
if let credProvider = credentialsProvider {
139-
self.credentialsProvider = credProvider
139+
self.credentialsProvider = try AWSClientRuntime.AWSCredentialsProvider.fromCustom(credProvider)
140140
} else {
141141
self.credentialsProvider = try AWSClientRuntime.AWSCredentialsProvider.fromChain()
142142
}
@@ -151,7 +151,7 @@ class RestJsonProtocolGeneratorTests {
151151
}
152152
153153
public convenience init(
154-
credentialsProvider: AWSClientRuntime.AWSCredentialsProvider? = nil,
154+
credentialsProvider: AWSClientRuntime.CredentialsProvider? = nil,
155155
endpointResolver: AWSClientRuntime.EndpointResolver? = nil,
156156
region: Swift.String? = nil,
157157
regionResolver: AWSClientRuntime.RegionResolver? = nil,

0 commit comments

Comments
 (0)