Skip to content

Commit f386a64

Browse files
committed
updated RefreshUserPoolTokens to use GetTokensFromRefreshToken API to enable refresh token rotation, also updated test mock clients and added unit tests
1 parent 5866f82 commit f386a64

File tree

6 files changed

+113
-28
lines changed

6 files changed

+113
-28
lines changed

AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Actions/RefreshAuthorizationSession/UserPool/RefreshUserPoolTokens.swift

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,32 +28,34 @@ struct RefreshUserPoolTokens: Action {
2828
return
2929
}
3030

31-
let authEnv = try environment.authEnvironment()
3231
let config = environment.userPoolConfiguration
3332
let client = try? environment.cognitoUserPoolFactory()
3433
let existingTokens = existingSignedIndata.cognitoUserPoolTokens
3534

3635
let deviceMetadata = await DeviceMetadataHelper.getDeviceMetadata(
3736
for: existingSignedIndata.username,
3837
with: environment)
39-
40-
let asfDeviceId = try await CognitoUserPoolASF.asfDeviceID(
41-
for: existingSignedIndata.username,
42-
credentialStoreClient: authEnv.credentialsClient)
43-
44-
let input = await InitiateAuthInput.refreshAuthInput(
45-
username: existingSignedIndata.username,
46-
refreshToken: existingTokens.refreshToken,
38+
39+
let deviceKey: String? = {
40+
if case .metadata(let data) = deviceMetadata {
41+
return data.deviceKey
42+
}
43+
return nil
44+
}()
45+
46+
let input = GetTokensFromRefreshTokenInput(
47+
clientId: config.clientId,
4748
clientMetadata: [:],
48-
asfDeviceId: asfDeviceId,
49-
deviceMetadata: deviceMetadata,
50-
environment: environment)
49+
clientSecret: config.clientSecret,
50+
deviceKey: deviceKey,
51+
refreshToken: existingTokens.refreshToken
52+
)
5153

52-
logVerbose("\(#fileID) Starting initiate auth refresh token", environment: environment)
54+
logVerbose("\(#fileID) Starting get tokens from refresh token", environment: environment)
5355

54-
let response = try await client?.initiateAuth(input: input)
56+
let response = try await client?.getTokensFromRefreshToken(input: input)
5557

56-
logVerbose("\(#fileID) Initiate auth response received", environment: environment)
58+
logVerbose("\(#fileID) Get tokens from refresh token response received", environment: environment)
5759

5860
guard let authenticationResult = response?.authenticationResult,
5961
let idToken = authenticationResult.idToken,
@@ -69,8 +71,7 @@ struct RefreshUserPoolTokens: Action {
6971
let userPoolTokens = AWSCognitoUserPoolTokens(
7072
idToken: idToken,
7173
accessToken: accessToken,
72-
refreshToken: existingTokens.refreshToken,
73-
expiresIn: authenticationResult.expiresIn
74+
refreshToken: authenticationResult.refreshToken ?? existingTokens.refreshToken
7475
)
7576
let signedInData = SignedInData(
7677
signedInDate: existingSignedIndata.signedInDate,
@@ -96,7 +97,7 @@ struct RefreshUserPoolTokens: Action {
9697
await dispatcher.send(event)
9798
}
9899

99-
logVerbose("\(#fileID) Initiate auth complete", environment: environment)
100+
logVerbose("\(#fileID) Get tokens from refresh token complete", environment: environment)
100101
}
101102
}
102103

AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Service/CognitoUserPoolBehavior.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ protocol CognitoUserPoolBehavior {
2929
/// Throws RevokeTokenOutputError
3030
func revokeToken(input: RevokeTokenInput) async throws -> RevokeTokenOutput
3131

32+
/// Throws GetTokensFromRefreshTokenOutputError
33+
func getTokensFromRefreshToken(input: GetTokensFromRefreshTokenInput) async throws -> GetTokensFromRefreshTokenOutput
34+
3235
// MARK: - User Attribute API's
3336

3437
/// Throws GetUserAttributeVerificationCodeOutputError

AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/FetchAuthSession/FetchUserPoolTokens/RefreshUserPoolTokensTests.swift

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ class RefreshUserPoolTokensTests: XCTestCase {
4545
let expectation = expectation(description: "refreshUserPoolTokens")
4646
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
4747
MockIdentityProvider(
48-
mockInitiateAuthResponse: { _ in
49-
return InitiateAuthOutput()
48+
mockGetTokensFromRefreshTokenResponse: { _ in
49+
return GetTokensFromRefreshTokenOutput()
5050
}
5151
)
5252
}
@@ -77,8 +77,8 @@ class RefreshUserPoolTokensTests: XCTestCase {
7777
let expectation = expectation(description: "refreshUserPoolTokens")
7878
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
7979
MockIdentityProvider(
80-
mockInitiateAuthResponse: { _ in
81-
return InitiateAuthOutput(
80+
mockGetTokensFromRefreshTokenResponse: { _ in
81+
return GetTokensFromRefreshTokenOutput(
8282
authenticationResult: .init(
8383
accessToken: "accessTokenNew",
8484
expiresIn: 100,
@@ -114,7 +114,7 @@ class RefreshUserPoolTokensTests: XCTestCase {
114114

115115
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
116116
MockIdentityProvider(
117-
mockInitiateAuthResponse: { _ in
117+
mockGetTokensFromRefreshTokenResponse: { _ in
118118
throw testError
119119
}
120120
)
@@ -144,4 +144,74 @@ class RefreshUserPoolTokensTests: XCTestCase {
144144
)
145145
}
146146

147+
func testRefreshTokenRotation() async {
148+
149+
let expectation = expectation(description: "refreshTokenRotation")
150+
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
151+
MockIdentityProvider(
152+
mockGetTokensFromRefreshTokenResponse: { _ in
153+
return GetTokensFromRefreshTokenOutput(
154+
authenticationResult: .init(
155+
accessToken: "accessTokenNew",
156+
expiresIn: 100,
157+
idToken: "idTokenNew",
158+
refreshToken: "refreshTokenRotated"))
159+
}
160+
)
161+
}
162+
163+
let action = RefreshUserPoolTokens(existingSignedIndata: .testData)
164+
165+
await action.execute(withDispatcher: MockDispatcher { event in
166+
167+
if let userPoolEvent = event as? RefreshSessionEvent,
168+
case let .refreshIdentityInfo(signedInData, _) = userPoolEvent.eventType {
169+
XCTAssertEqual(signedInData.cognitoUserPoolTokens.refreshToken, "refreshTokenRotated")
170+
expectation.fulfill()
171+
}
172+
}, environment: Defaults.makeDefaultAuthEnvironment(
173+
userPoolFactory: identityProviderFactory)
174+
)
175+
176+
await fulfillment(
177+
of: [expectation],
178+
timeout: 0.1
179+
)
180+
}
181+
182+
func testRefreshTokenNoRotation() async {
183+
184+
let expectation = expectation(description: "refreshTokenNoRotation")
185+
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
186+
MockIdentityProvider(
187+
mockGetTokensFromRefreshTokenResponse: { _ in
188+
return GetTokensFromRefreshTokenOutput(
189+
authenticationResult: .init(
190+
accessToken: "accessTokenNew",
191+
expiresIn: 100,
192+
idToken: "idTokenNew",
193+
refreshToken: nil))
194+
}
195+
)
196+
}
197+
198+
let action = RefreshUserPoolTokens(existingSignedIndata: .testData)
199+
200+
await action.execute(withDispatcher: MockDispatcher { event in
201+
202+
if let userPoolEvent = event as? RefreshSessionEvent,
203+
case let .refreshIdentityInfo(signedInData, _) = userPoolEvent.eventType {
204+
XCTAssertEqual(signedInData.cognitoUserPoolTokens.refreshToken, "refreshToken")
205+
expectation.fulfill()
206+
}
207+
}, environment: Defaults.makeDefaultAuthEnvironment(
208+
userPoolFactory: identityProviderFactory)
209+
)
210+
211+
await fulfillment(
212+
of: [expectation],
213+
timeout: 0.1
214+
)
215+
}
216+
147217
}

AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/Support/MockIdentityProvider.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ struct MockIdentityProvider: CognitoUserPoolBehavior {
2020
typealias MockInitiateAuthResponse = (InitiateAuthInput) async throws
2121
-> InitiateAuthOutput
2222

23+
typealias MockGetTokensFromRefreshTokenResponse = (GetTokensFromRefreshTokenInput) async throws
24+
-> GetTokensFromRefreshTokenOutput
25+
2326
typealias MockConfirmSignUpResponse = (ConfirmSignUpInput) async throws
2427
-> ConfirmSignUpOutput
2528

@@ -88,6 +91,7 @@ struct MockIdentityProvider: CognitoUserPoolBehavior {
8891
let mockSignUpResponse: MockSignUpResponse?
8992
let mockRevokeTokenResponse: MockRevokeTokenResponse?
9093
let mockInitiateAuthResponse: MockInitiateAuthResponse?
94+
let mockGetTokensFromRefreshTokenResponse: MockGetTokensFromRefreshTokenResponse?
9195
let mockGlobalSignOutResponse: MockGlobalSignOutResponse?
9296
let mockConfirmSignUpResponse: MockConfirmSignUpResponse?
9397
let mockRespondToAuthChallengeResponse: MockRespondToAuthChallengeResponse?
@@ -116,6 +120,7 @@ struct MockIdentityProvider: CognitoUserPoolBehavior {
116120
mockSignUpResponse: MockSignUpResponse? = nil,
117121
mockRevokeTokenResponse: MockRevokeTokenResponse? = nil,
118122
mockInitiateAuthResponse: MockInitiateAuthResponse? = nil,
123+
mockGetTokensFromRefreshTokenResponse: MockGetTokensFromRefreshTokenResponse? = nil,
119124
mockGlobalSignOutResponse: MockGlobalSignOutResponse? = nil,
120125
mockConfirmSignUpResponse: MockConfirmSignUpResponse? = nil,
121126
mockRespondToAuthChallengeResponse: MockRespondToAuthChallengeResponse? = nil,
@@ -139,6 +144,7 @@ struct MockIdentityProvider: CognitoUserPoolBehavior {
139144
self.mockSignUpResponse = mockSignUpResponse
140145
self.mockRevokeTokenResponse = mockRevokeTokenResponse
141146
self.mockInitiateAuthResponse = mockInitiateAuthResponse
147+
self.mockGetTokensFromRefreshTokenResponse = mockGetTokensFromRefreshTokenResponse
142148
self.mockGlobalSignOutResponse = mockGlobalSignOutResponse
143149
self.mockConfirmSignUpResponse = mockConfirmSignUpResponse
144150
self.mockRespondToAuthChallengeResponse = mockRespondToAuthChallengeResponse
@@ -192,6 +198,11 @@ struct MockIdentityProvider: CognitoUserPoolBehavior {
192198
return try await mockRevokeTokenResponse!(input)
193199
}
194200

201+
/// Throws GetTokensFromRefreshTokenOutputError
202+
func getTokensFromRefreshToken(input: GetTokensFromRefreshTokenInput) async throws -> GetTokensFromRefreshTokenOutput {
203+
return try await mockGetTokensFromRefreshTokenResponse!(input)
204+
}
205+
195206
func getUserAttributeVerificationCode(input: GetUserAttributeVerificationCodeInput) async throws -> GetUserAttributeVerificationCodeOutput {
196207
return try await mockGetUserAttributeVerificationCodeOutput!(input)
197208
}

Package.resolved

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ let platforms: [SupportedPlatform] = [
99
.watchOS(.v9)
1010
]
1111
let dependencies: [Package.Dependency] = [
12-
.package(url: "https://github.com/awslabs/aws-sdk-swift", exact: "1.5.14"),
12+
.package(url: "https://github.com/awslabs/aws-sdk-swift", exact: "1.5.18"),
1313
.package(url: "https://github.com/stephencelis/SQLite.swift.git", exact: "0.15.3"),
1414
.package(url: "https://github.com/mattgallagher/CwlPreconditionTesting.git", from: "2.1.0"),
1515
.package(url: "https://github.com/aws-amplify/amplify-swift-utils-notifications.git", from: "1.1.0")

0 commit comments

Comments
 (0)