Skip to content

Commit 905e9e8

Browse files
authored
feat(auth): Refresh Token Rotation (#4050)
* removed usage of old refresh token, updated test case to properly check for invalidTokens error in case the API returns a null token * updated authfetchsigninsession tests to use GetTokenFromRefreshToken in mock clients * updated further tests to use new API as well as auth hub event handler tests * updated RefreshUserPoolTokens to use GetTokensFromRefreshToken API to enable refresh token rotation, also updated test mock clients and added unit tests * removed refreshAuth function from initiateAuthInput
1 parent 9a024d9 commit 905e9e8

File tree

9 files changed

+194
-112
lines changed

9 files changed

+194
-112
lines changed

AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Actions/RefreshAuthorizationSession/UserPool/RefreshUserPoolTokens.swift

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
// SPDX-License-Identifier: Apache-2.0
66
//
77

8-
import Amplify
9-
import AWSPluginsCore
108
import AWSCognitoIdentityProvider
11-
import Foundation
9+
import AWSPluginsCore
10+
import Amplify
1211
import ClientRuntime
12+
import Foundation
1313

1414
struct RefreshUserPoolTokens: Action {
1515

@@ -28,7 +28,6 @@ struct RefreshUserPoolTokens: Action {
2828
return
2929
}
3030

31-
let authEnv = try environment.authEnvironment()
3231
let config = environment.userPoolConfiguration
3332
let client = try? environment.cognitoUserPoolFactory()
3433
let existingTokens = existingSignedIndata.cognitoUserPoolTokens
@@ -37,29 +36,35 @@ struct RefreshUserPoolTokens: Action {
3736
for: existingSignedIndata.username,
3837
with: environment)
3938

40-
let asfDeviceId = try await CognitoUserPoolASF.asfDeviceID(
41-
for: existingSignedIndata.username,
42-
credentialStoreClient: authEnv.credentialsClient)
39+
let deviceKey: String? = {
40+
if case .metadata(let data) = deviceMetadata {
41+
return data.deviceKey
42+
}
43+
return nil
44+
}()
4345

44-
let input = await InitiateAuthInput.refreshAuthInput(
45-
username: existingSignedIndata.username,
46-
refreshToken: existingTokens.refreshToken,
46+
let input = GetTokensFromRefreshTokenInput(
47+
clientId: config.clientId,
4748
clientMetadata: [:],
48-
asfDeviceId: asfDeviceId,
49-
deviceMetadata: deviceMetadata,
50-
environment: environment)
49+
clientSecret: config.clientSecret,
50+
deviceKey: deviceKey,
51+
refreshToken: existingTokens.refreshToken
52+
)
5153

52-
logVerbose("\(#fileID) Starting initiate auth refresh token", environment: environment)
54+
logVerbose(
55+
"\(#fileID) Starting get tokens from refresh token", environment: environment)
5356

54-
let response = try await client?.initiateAuth(input: input)
57+
let response = try await client?.getTokensFromRefreshToken(input: input)
5558

56-
logVerbose("\(#fileID) Initiate auth response received", environment: environment)
59+
logVerbose(
60+
"\(#fileID) Get tokens from refresh token response received",
61+
environment: environment)
5762

5863
guard let authenticationResult = response?.authenticationResult,
59-
let idToken = authenticationResult.idToken,
60-
let accessToken = authenticationResult.accessToken
64+
let idToken = authenticationResult.idToken,
65+
let accessToken = authenticationResult.accessToken,
66+
let refreshToken = authenticationResult.refreshToken
6167
else {
62-
6368
let event = RefreshSessionEvent(eventType: .throwError(.invalidTokens))
6469
await dispatcher.send(event)
6570
logVerbose("\(#fileID) Sending event \(event.type)", environment: environment)
@@ -69,9 +74,9 @@ struct RefreshUserPoolTokens: Action {
6974
let userPoolTokens = AWSCognitoUserPoolTokens(
7075
idToken: idToken,
7176
accessToken: accessToken,
72-
refreshToken: existingTokens.refreshToken,
73-
expiresIn: authenticationResult.expiresIn
77+
refreshToken: refreshToken
7478
)
79+
7580
let signedInData = SignedInData(
7681
signedInDate: existingSignedIndata.signedInDate,
7782
signInMethod: existingSignedIndata.signInMethod,
@@ -96,13 +101,14 @@ struct RefreshUserPoolTokens: Action {
96101
await dispatcher.send(event)
97102
}
98103

99-
logVerbose("\(#fileID) Initiate auth complete", environment: environment)
104+
logVerbose("\(#fileID) Get tokens from refresh token complete", environment: environment)
100105
}
101106
}
102107

103108
extension RefreshUserPoolTokens: DefaultLogger {
104109
public static var log: Logger {
105-
Amplify.Logging.logger(forCategory: CategoryType.auth.displayName, forNamespace: String(describing: self))
110+
Amplify.Logging.logger(
111+
forCategory: CategoryType.auth.displayName, forNamespace: String(describing: self))
106112
}
107113

108114
public var log: Logger {
@@ -114,7 +120,7 @@ extension RefreshUserPoolTokens: CustomDebugDictionaryConvertible {
114120
var debugDictionary: [String: Any] {
115121
[
116122
"identifier": identifier,
117-
"existingSignedInData": existingSignedIndata
123+
"existingSignedInData": existingSignedIndata,
118124
]
119125
}
120126
}

AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Service/CognitoUserPoolBehavior.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ protocol CognitoUserPoolBehavior {
2929
/// Throws RevokeTokenOutputError
3030
func revokeToken(input: RevokeTokenInput) async throws -> RevokeTokenOutput
3131

32+
/// Throws GetTokensFromRefreshTokenOutputError
33+
func getTokensFromRefreshToken(input: GetTokensFromRefreshTokenInput) async throws -> GetTokensFromRefreshTokenOutput
34+
3235
// MARK: - User Attribute API's
3336

3437
/// Throws GetUserAttributeVerificationCodeOutputError

AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/Utils/InitiateAuthInput+Amplify.swift

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -93,27 +93,6 @@ extension InitiateAuthInput {
9393
environment: environment)
9494
}
9595

96-
static func refreshAuthInput(username: String,
97-
refreshToken: String,
98-
clientMetadata: [String: String],
99-
asfDeviceId: String,
100-
deviceMetadata: DeviceMetadata,
101-
environment: UserPoolEnvironment) async -> InitiateAuthInput {
102-
103-
let authParameters = [
104-
"REFRESH_TOKEN": refreshToken
105-
]
106-
107-
return await buildInput(username: username,
108-
authFlowType: .refreshTokenAuth,
109-
authParameters: authParameters,
110-
clientMetadata: clientMetadata,
111-
asfDeviceId: asfDeviceId,
112-
deviceMetadata: deviceMetadata,
113-
environment: environment)
114-
115-
}
116-
11796
static func buildInput(username: String,
11897
authFlowType: CognitoIdentityProviderClientTypes.AuthFlowType,
11998
authParameters: [String: String],

AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ActionTests/FetchAuthSession/FetchUserPoolTokens/RefreshUserPoolTokensTests.swift

Lines changed: 125 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
// SPDX-License-Identifier: Apache-2.0
66
//
77

8-
import XCTest
9-
import Amplify
10-
import AWSPluginsCore
118
import AWSCognitoIdentityProvider
9+
import AWSPluginsCore
10+
import Amplify
11+
import XCTest
1212

1313
@testable import AWSCognitoAuthPlugin
1414

@@ -20,18 +20,19 @@ class RefreshUserPoolTokensTests: XCTestCase {
2020

2121
let action = RefreshUserPoolTokens(existingSignedIndata: .testData)
2222

23-
await action.execute(withDispatcher: MockDispatcher { event in
23+
await action.execute(
24+
withDispatcher: MockDispatcher { event in
2425

25-
guard let event = event as? RefreshSessionEvent else {
26-
return
27-
}
26+
guard let event = event as? RefreshSessionEvent else {
27+
return
28+
}
2829

29-
if case let .throwError(error) = event.eventType {
30-
XCTAssertNotNil(error)
31-
XCTAssertEqual(error, .noUserPool)
32-
expectation.fulfill()
33-
}
34-
}, environment: MockInvalidEnvironment()
30+
if case let .throwError(error) = event.eventType {
31+
XCTAssertNotNil(error)
32+
XCTAssertEqual(error, .noUserPool)
33+
expectation.fulfill()
34+
}
35+
}, environment: MockInvalidEnvironment()
3536
)
3637

3738
await fulfillment(
@@ -45,25 +46,27 @@ class RefreshUserPoolTokensTests: XCTestCase {
4546
let expectation = expectation(description: "refreshUserPoolTokens")
4647
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
4748
MockIdentityProvider(
48-
mockInitiateAuthResponse: { _ in
49-
return InitiateAuthOutput()
49+
mockGetTokensFromRefreshTokenResponse: { _ in
50+
return GetTokensFromRefreshTokenOutput()
5051
}
5152
)
5253
}
5354

5455
let action = RefreshUserPoolTokens(existingSignedIndata: .testData)
5556

56-
await action.execute(withDispatcher: MockDispatcher { event in
57+
await action.execute(
58+
withDispatcher: MockDispatcher { event in
5759

58-
guard let event = event as? RefreshSessionEvent else { return }
60+
guard let event = event as? RefreshSessionEvent else { return }
5961

60-
if case let .throwError(error) = event.eventType {
61-
XCTAssertNotNil(error)
62-
XCTAssertEqual(error, .invalidTokens)
63-
expectation.fulfill()
64-
}
65-
}, environment: Defaults.makeDefaultAuthEnvironment(
66-
userPoolFactory: identityProviderFactory)
62+
if case let .throwError(error) = event.eventType {
63+
XCTAssertNotNil(error)
64+
XCTAssertEqual(error, .invalidTokens)
65+
expectation.fulfill()
66+
}
67+
},
68+
environment: Defaults.makeDefaultAuthEnvironment(
69+
userPoolFactory: identityProviderFactory)
6770
)
6871

6972
await fulfillment(
@@ -77,8 +80,8 @@ class RefreshUserPoolTokensTests: XCTestCase {
7780
let expectation = expectation(description: "refreshUserPoolTokens")
7881
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
7982
MockIdentityProvider(
80-
mockInitiateAuthResponse: { _ in
81-
return InitiateAuthOutput(
83+
mockGetTokensFromRefreshTokenResponse: { _ in
84+
return GetTokensFromRefreshTokenOutput(
8285
authenticationResult: .init(
8386
accessToken: "accessTokenNew",
8487
expiresIn: 100,
@@ -90,14 +93,17 @@ class RefreshUserPoolTokensTests: XCTestCase {
9093

9194
let action = RefreshUserPoolTokens(existingSignedIndata: .testData)
9295

93-
await action.execute(withDispatcher: MockDispatcher { event in
96+
await action.execute(
97+
withDispatcher: MockDispatcher { event in
9498

95-
if let userPoolEvent = event as? RefreshSessionEvent,
96-
case .refreshIdentityInfo = userPoolEvent.eventType {
97-
expectation.fulfill()
98-
}
99-
}, environment: Defaults.makeDefaultAuthEnvironment(
100-
userPoolFactory: identityProviderFactory)
99+
if let userPoolEvent = event as? RefreshSessionEvent,
100+
case .refreshIdentityInfo = userPoolEvent.eventType
101+
{
102+
expectation.fulfill()
103+
}
104+
},
105+
environment: Defaults.makeDefaultAuthEnvironment(
106+
userPoolFactory: identityProviderFactory)
101107
)
102108

103109
await fulfillment(
@@ -114,7 +120,7 @@ class RefreshUserPoolTokensTests: XCTestCase {
114120

115121
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
116122
MockIdentityProvider(
117-
mockInitiateAuthResponse: { _ in
123+
mockGetTokensFromRefreshTokenResponse: { _ in
118124
throw testError
119125
}
120126
)
@@ -128,20 +134,97 @@ class RefreshUserPoolTokensTests: XCTestCase {
128134

129135
let action = RefreshUserPoolTokens(existingSignedIndata: .testData)
130136

131-
await action.execute(withDispatcher: MockDispatcher { event in
137+
await action.execute(
138+
withDispatcher: MockDispatcher { event in
132139

133-
if let userPoolEvent = event as? RefreshSessionEvent,
134-
case let .throwError(error) = userPoolEvent.eventType {
135-
XCTAssertNotNil(error)
136-
XCTAssertEqual(error, .service(testError))
137-
expectation.fulfill()
138-
}
139-
}, environment: environment)
140+
if let userPoolEvent = event as? RefreshSessionEvent,
141+
case let .throwError(error) = userPoolEvent.eventType
142+
{
143+
XCTAssertNotNil(error)
144+
XCTAssertEqual(error, .service(testError))
145+
expectation.fulfill()
146+
}
147+
}, environment: environment)
140148

141149
await fulfillment(
142150
of: [expectation],
143151
timeout: 0.1
144152
)
145153
}
146154

155+
func testRefreshTokenRotation() async {
156+
157+
let expectation = expectation(description: "refreshTokenRotation")
158+
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
159+
MockIdentityProvider(
160+
mockGetTokensFromRefreshTokenResponse: { _ in
161+
return GetTokensFromRefreshTokenOutput(
162+
authenticationResult: .init(
163+
accessToken: "accessTokenNew",
164+
expiresIn: 100,
165+
idToken: "idTokenNew",
166+
refreshToken: "refreshTokenRotated"))
167+
}
168+
)
169+
}
170+
171+
let action = RefreshUserPoolTokens(existingSignedIndata: .testData)
172+
173+
await action.execute(
174+
withDispatcher: MockDispatcher { event in
175+
176+
if let userPoolEvent = event as? RefreshSessionEvent,
177+
case let .refreshIdentityInfo(signedInData, _) = userPoolEvent.eventType
178+
{
179+
XCTAssertEqual(
180+
signedInData.cognitoUserPoolTokens.refreshToken, "refreshTokenRotated")
181+
expectation.fulfill()
182+
}
183+
},
184+
environment: Defaults.makeDefaultAuthEnvironment(
185+
userPoolFactory: identityProviderFactory)
186+
)
187+
188+
await fulfillment(
189+
of: [expectation],
190+
timeout: 0.1
191+
)
192+
}
193+
func testRefreshTokenMissing() async {
194+
195+
let expectation = expectation(description: "refreshTokenMissing")
196+
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
197+
MockIdentityProvider(
198+
mockGetTokensFromRefreshTokenResponse: { _ in
199+
return GetTokensFromRefreshTokenOutput(
200+
authenticationResult: .init(
201+
accessToken: "accessTokenNew",
202+
expiresIn: 100,
203+
idToken: "idTokenNew",
204+
refreshToken: nil))
205+
}
206+
)
207+
}
208+
209+
let action = RefreshUserPoolTokens(existingSignedIndata: .testData)
210+
211+
await action.execute(
212+
withDispatcher: MockDispatcher { event in
213+
214+
if let userPoolEvent = event as? RefreshSessionEvent,
215+
case let .throwError(error) = userPoolEvent.eventType
216+
{
217+
XCTAssertEqual(error, .invalidTokens)
218+
expectation.fulfill()
219+
}
220+
},
221+
environment: Defaults.makeDefaultAuthEnvironment(
222+
userPoolFactory: identityProviderFactory)
223+
)
224+
225+
await fulfillment(
226+
of: [expectation],
227+
timeout: 0.1
228+
)
229+
}
147230
}

AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/HubEventTests/AuthHubEventHandlerTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ class AuthHubEventHandlerTests: XCTestCase {
423423
.notStarted)
424424

425425
let mockIdentityProvider = MockIdentityProvider(
426-
mockInitiateAuthResponse: { _ in
426+
mockGetTokensFromRefreshTokenResponse: { _ in
427427
throw AWSCognitoIdentityProvider.NotAuthorizedException()
428428
})
429429

0 commit comments

Comments
 (0)