Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,34 @@ struct RefreshUserPoolTokens: Action {
return
}

let authEnv = try environment.authEnvironment()
let config = environment.userPoolConfiguration
let client = try? environment.cognitoUserPoolFactory()
let existingTokens = existingSignedIndata.cognitoUserPoolTokens

let deviceMetadata = await DeviceMetadataHelper.getDeviceMetadata(
for: existingSignedIndata.username,
with: environment)

let asfDeviceId = try await CognitoUserPoolASF.asfDeviceID(
for: existingSignedIndata.username,
credentialStoreClient: authEnv.credentialsClient)

let input = await InitiateAuthInput.refreshAuthInput(
username: existingSignedIndata.username,
refreshToken: existingTokens.refreshToken,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should remove the implementation of refreshAuthInput method as its not longer required.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is already removed in this PR unless I missed something

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant the implementation in InitiateAuthInput+Amplify.swift

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it


let deviceKey: String? = {
if case .metadata(let data) = deviceMetadata {
return data.deviceKey
}
return nil
}()

let input = GetTokensFromRefreshTokenInput(
clientId: config.clientId,
clientMetadata: [:],
asfDeviceId: asfDeviceId,
deviceMetadata: deviceMetadata,
environment: environment)
clientSecret: config.clientSecret,
deviceKey: deviceKey,
refreshToken: existingTokens.refreshToken
)

logVerbose("\(#fileID) Starting initiate auth refresh token", environment: environment)
logVerbose("\(#fileID) Starting get tokens from refresh token", environment: environment)

let response = try await client?.initiateAuth(input: input)
let response = try await client?.getTokensFromRefreshToken(input: input)

logVerbose("\(#fileID) Initiate auth response received", environment: environment)
logVerbose("\(#fileID) Get tokens from refresh token response received", environment: environment)

guard let authenticationResult = response?.authenticationResult,
let idToken = authenticationResult.idToken,
Expand All @@ -69,8 +71,7 @@ struct RefreshUserPoolTokens: Action {
let userPoolTokens = AWSCognitoUserPoolTokens(
idToken: idToken,
accessToken: accessToken,
refreshToken: existingTokens.refreshToken,
expiresIn: authenticationResult.expiresIn
refreshToken: authenticationResult.refreshToken ?? existingTokens.refreshToken
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
refreshToken: existingTokens.refreshToken,
expiresIn: authenticationResult.expiresIn
refreshToken: authenticationResult.refreshToken ?? existingTokens.refreshToken
guard let authenticationResult = response?.authenticationResult,
let idToken = authenticationResult.idToken,
let accessToken = authenticationResult.accessToken,
let refreshToken = authenticationResult.refreshToken. <------- change
else {
let event = RefreshSessionEvent(eventType: .throwError(.invalidTokens))
await dispatcher.send(event)
logVerbose("\(#fileID) Sending event \(event.type)", environment: environment)
return
}
let userPoolTokens = AWSCognitoUserPoolTokens(
idToken: idToken,
accessToken: accessToken,
refreshToken: refreshToken. <------- change

I don't think we should use the old refresh token at this point.. Should throw an error if the refresh token doesn't exist, as we have entered an unknown state because of missing tokens.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, there's no use for the old token, this was here because its possible the user doesn't have rotation enabled but in that case the returned token would just be their one token anyways

)
let signedInData = SignedInData(
signedInDate: existingSignedIndata.signedInDate,
Expand All @@ -96,7 +97,7 @@ struct RefreshUserPoolTokens: Action {
await dispatcher.send(event)
}

logVerbose("\(#fileID) Initiate auth complete", environment: environment)
logVerbose("\(#fileID) Get tokens from refresh token complete", environment: environment)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ protocol CognitoUserPoolBehavior {
/// Throws RevokeTokenOutputError
func revokeToken(input: RevokeTokenInput) async throws -> RevokeTokenOutput

/// Throws GetTokensFromRefreshTokenOutputError
func getTokensFromRefreshToken(input: GetTokensFromRefreshTokenInput) async throws -> GetTokensFromRefreshTokenOutput

// MARK: - User Attribute API's

/// Throws GetUserAttributeVerificationCodeOutputError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class RefreshUserPoolTokensTests: XCTestCase {
let expectation = expectation(description: "refreshUserPoolTokens")
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
MockIdentityProvider(
mockInitiateAuthResponse: { _ in
return InitiateAuthOutput()
mockGetTokensFromRefreshTokenResponse: { _ in
return GetTokensFromRefreshTokenOutput()
}
)
}
Expand Down Expand Up @@ -77,8 +77,8 @@ class RefreshUserPoolTokensTests: XCTestCase {
let expectation = expectation(description: "refreshUserPoolTokens")
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
MockIdentityProvider(
mockInitiateAuthResponse: { _ in
return InitiateAuthOutput(
mockGetTokensFromRefreshTokenResponse: { _ in
return GetTokensFromRefreshTokenOutput(
authenticationResult: .init(
accessToken: "accessTokenNew",
expiresIn: 100,
Expand Down Expand Up @@ -114,7 +114,7 @@ class RefreshUserPoolTokensTests: XCTestCase {

let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
MockIdentityProvider(
mockInitiateAuthResponse: { _ in
mockGetTokensFromRefreshTokenResponse: { _ in
throw testError
}
)
Expand Down Expand Up @@ -144,4 +144,74 @@ class RefreshUserPoolTokensTests: XCTestCase {
)
}

func testRefreshTokenRotation() async {

let expectation = expectation(description: "refreshTokenRotation")
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
MockIdentityProvider(
mockGetTokensFromRefreshTokenResponse: { _ in
return GetTokensFromRefreshTokenOutput(
authenticationResult: .init(
accessToken: "accessTokenNew",
expiresIn: 100,
idToken: "idTokenNew",
refreshToken: "refreshTokenRotated"))
}
)
}

let action = RefreshUserPoolTokens(existingSignedIndata: .testData)

await action.execute(withDispatcher: MockDispatcher { event in

if let userPoolEvent = event as? RefreshSessionEvent,
case let .refreshIdentityInfo(signedInData, _) = userPoolEvent.eventType {
XCTAssertEqual(signedInData.cognitoUserPoolTokens.refreshToken, "refreshTokenRotated")
expectation.fulfill()
}
}, environment: Defaults.makeDefaultAuthEnvironment(
userPoolFactory: identityProviderFactory)
)

await fulfillment(
of: [expectation],
timeout: 0.1
)
}

func testRefreshTokenNoRotation() async {

let expectation = expectation(description: "refreshTokenNoRotation")
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
MockIdentityProvider(
mockGetTokensFromRefreshTokenResponse: { _ in
return GetTokensFromRefreshTokenOutput(
authenticationResult: .init(
accessToken: "accessTokenNew",
expiresIn: 100,
idToken: "idTokenNew",
refreshToken: nil))
Comment on lines +199 to +204
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the refresh token is nil, can that ever happen.. I personally think it shouldn't and this test case should throw and error instead.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, while the API can return a null token we should just throw an error in that case

}
)
}

let action = RefreshUserPoolTokens(existingSignedIndata: .testData)

await action.execute(withDispatcher: MockDispatcher { event in

if let userPoolEvent = event as? RefreshSessionEvent,
case let .refreshIdentityInfo(signedInData, _) = userPoolEvent.eventType {
XCTAssertEqual(signedInData.cognitoUserPoolTokens.refreshToken, "refreshToken")
expectation.fulfill()
}
}, environment: Defaults.makeDefaultAuthEnvironment(
userPoolFactory: identityProviderFactory)
)

await fulfillment(
of: [expectation],
timeout: 0.1
)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ class AuthHubEventHandlerTests: XCTestCase {
.notStarted)

let mockIdentityProvider = MockIdentityProvider(
mockInitiateAuthResponse: { _ in
mockGetTokensFromRefreshTokenResponse: { _ in
throw AWSCognitoIdentityProvider.NotAuthorizedException()
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ struct MockIdentityProvider: CognitoUserPoolBehavior {
typealias MockInitiateAuthResponse = (InitiateAuthInput) async throws
-> InitiateAuthOutput

typealias MockGetTokensFromRefreshTokenResponse = (GetTokensFromRefreshTokenInput) async throws
-> GetTokensFromRefreshTokenOutput

typealias MockConfirmSignUpResponse = (ConfirmSignUpInput) async throws
-> ConfirmSignUpOutput

Expand Down Expand Up @@ -88,6 +91,7 @@ struct MockIdentityProvider: CognitoUserPoolBehavior {
let mockSignUpResponse: MockSignUpResponse?
let mockRevokeTokenResponse: MockRevokeTokenResponse?
let mockInitiateAuthResponse: MockInitiateAuthResponse?
let mockGetTokensFromRefreshTokenResponse: MockGetTokensFromRefreshTokenResponse?
let mockGlobalSignOutResponse: MockGlobalSignOutResponse?
let mockConfirmSignUpResponse: MockConfirmSignUpResponse?
let mockRespondToAuthChallengeResponse: MockRespondToAuthChallengeResponse?
Expand Down Expand Up @@ -116,6 +120,7 @@ struct MockIdentityProvider: CognitoUserPoolBehavior {
mockSignUpResponse: MockSignUpResponse? = nil,
mockRevokeTokenResponse: MockRevokeTokenResponse? = nil,
mockInitiateAuthResponse: MockInitiateAuthResponse? = nil,
mockGetTokensFromRefreshTokenResponse: MockGetTokensFromRefreshTokenResponse? = nil,
mockGlobalSignOutResponse: MockGlobalSignOutResponse? = nil,
mockConfirmSignUpResponse: MockConfirmSignUpResponse? = nil,
mockRespondToAuthChallengeResponse: MockRespondToAuthChallengeResponse? = nil,
Expand All @@ -139,6 +144,7 @@ struct MockIdentityProvider: CognitoUserPoolBehavior {
self.mockSignUpResponse = mockSignUpResponse
self.mockRevokeTokenResponse = mockRevokeTokenResponse
self.mockInitiateAuthResponse = mockInitiateAuthResponse
self.mockGetTokensFromRefreshTokenResponse = mockGetTokensFromRefreshTokenResponse
self.mockGlobalSignOutResponse = mockGlobalSignOutResponse
self.mockConfirmSignUpResponse = mockConfirmSignUpResponse
self.mockRespondToAuthChallengeResponse = mockRespondToAuthChallengeResponse
Expand Down Expand Up @@ -192,6 +198,11 @@ struct MockIdentityProvider: CognitoUserPoolBehavior {
return try await mockRevokeTokenResponse!(input)
}

/// Throws GetTokensFromRefreshTokenOutputError
func getTokensFromRefreshToken(input: GetTokensFromRefreshTokenInput) async throws -> GetTokensFromRefreshTokenOutput {
return try await mockGetTokensFromRefreshTokenResponse!(input)
}

func getUserAttributeVerificationCode(input: GetUserAttributeVerificationCodeInput) async throws -> GetUserAttributeVerificationCodeOutput {
return try await mockGetUserAttributeVerificationCodeOutput!(input)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ class AWSAuthFetchSignInSessionOperationTests: BaseAuthorizationTests {
AuthorizationState.sessionEstablished(
AmplifyCredentials.testData),
.notStarted)
let initAuth: MockIdentityProvider.MockInitiateAuthResponse = { _ in
let getTokensFromRefreshToken: MockIdentityProvider.MockGetTokensFromRefreshTokenResponse = { _ in
resultExpectation.fulfill()
return InitiateAuthOutput(authenticationResult: .init(
return GetTokensFromRefreshTokenOutput(authenticationResult: .init(
accessToken: "accessToken",
expiresIn: 1000,
idToken: "idToken",
Expand All @@ -115,7 +115,7 @@ class AWSAuthFetchSignInSessionOperationTests: BaseAuthorizationTests {
}

let plugin = configurePluginWith(
userPool: { MockIdentityProvider(mockInitiateAuthResponse: initAuth) },
userPool: { MockIdentityProvider(mockGetTokensFromRefreshTokenResponse: getTokensFromRefreshToken) },
identityPool: { MockIdentity(mockGetCredentialsResponse: awsCredentials) },
initialState: initialState)
let session = try await plugin.fetchAuthSession(options: .forceRefresh())
Expand Down Expand Up @@ -212,11 +212,11 @@ class AWSAuthFetchSignInSessionOperationTests: BaseAuthorizationTests {
AmplifyCredentials.testDataWithExpiredTokens),
.notStarted)

let initAuth: MockIdentityProvider.MockInitiateAuthResponse = { _ in
let getTokensFromRefreshToken: MockIdentityProvider.MockGetTokensFromRefreshTokenResponse = { _ in
throw AWSCognitoIdentityProvider.NotAuthorizedException()
}

let plugin = configurePluginWith(userPool: { MockIdentityProvider(mockInitiateAuthResponse: initAuth) }, initialState: initialState)
let plugin = configurePluginWith(userPool: { MockIdentityProvider(mockGetTokensFromRefreshTokenResponse: getTokensFromRefreshToken) }, initialState: initialState)
let session = try await plugin.fetchAuthSession(options: AuthFetchSessionRequest.Options())
XCTAssertTrue(session.isSignedIn)

Expand Down Expand Up @@ -261,8 +261,8 @@ class AWSAuthFetchSignInSessionOperationTests: BaseAuthorizationTests {
AmplifyCredentials.testDataWithExpiredTokens),
.notStarted)

let initAuth: MockIdentityProvider.MockInitiateAuthResponse = { _ in
return InitiateAuthOutput(authenticationResult: .init(accessToken: "accessToken",
let getTokensFromRefreshToken: MockIdentityProvider.MockGetTokensFromRefreshTokenResponse = { _ in
return GetTokensFromRefreshTokenOutput(authenticationResult: .init(accessToken: "accessToken",
expiresIn: 1000,
idToken: "idToken",
refreshToken: "refreshToke"))
Expand All @@ -273,7 +273,7 @@ class AWSAuthFetchSignInSessionOperationTests: BaseAuthorizationTests {
}

let plugin = configurePluginWith(
userPool: { MockIdentityProvider(mockInitiateAuthResponse: initAuth) },
userPool: { MockIdentityProvider(mockGetTokensFromRefreshTokenResponse: getTokensFromRefreshToken) },
identityPool: { MockIdentity(mockGetCredentialsResponse: awsCredentials) },
initialState: initialState)

Expand Down Expand Up @@ -494,15 +494,15 @@ class AWSAuthFetchSignInSessionOperationTests: BaseAuthorizationTests {
AmplifyCredentials.testDataWithExpiredTokens),
.notStarted)

let initAuth: MockIdentityProvider.MockInitiateAuthResponse = { _ in
return InitiateAuthOutput(authenticationResult: .init(accessToken: nil,
let refreshTokenAuth: MockIdentityProvider.MockGetTokensFromRefreshTokenResponse = { _ in
return GetTokensFromRefreshTokenOutput(authenticationResult: .init(accessToken: nil,
expiresIn: 1000,
idToken: "idToken",
refreshToken: "refreshToke"))
}

let plugin = configurePluginWith(
userPool: { MockIdentityProvider(mockInitiateAuthResponse: initAuth) },
userPool: { MockIdentityProvider(mockGetTokensFromRefreshTokenResponse: refreshTokenAuth) },
initialState: initialState)

let session = try await plugin.fetchAuthSession(options: AuthFetchSessionRequest.Options())
Expand Down Expand Up @@ -548,8 +548,8 @@ class AWSAuthFetchSignInSessionOperationTests: BaseAuthorizationTests {
AmplifyCredentials.testDataWithExpiredTokens),
.notStarted)

let initAuth: MockIdentityProvider.MockInitiateAuthResponse = { _ in
return InitiateAuthOutput(authenticationResult: .init(accessToken: "accessToken",
let refreshTokenAuth: MockIdentityProvider.MockGetTokensFromRefreshTokenResponse = { _ in
return GetTokensFromRefreshTokenOutput(authenticationResult: .init(accessToken: "accessToken",
expiresIn: 1000,
idToken: "idToken",
refreshToken: "refreshToke"))
Expand All @@ -559,7 +559,7 @@ class AWSAuthFetchSignInSessionOperationTests: BaseAuthorizationTests {
return GetCredentialsForIdentityOutput(credentials: nil, identityId: "ss")
}
let plugin = configurePluginWith(
userPool: { MockIdentityProvider(mockInitiateAuthResponse: initAuth) },
userPool: { MockIdentityProvider(mockGetTokensFromRefreshTokenResponse: refreshTokenAuth) },
identityPool: { MockIdentity(mockGetCredentialsResponse: awsCredentials) },
initialState: initialState)

Expand Down Expand Up @@ -714,12 +714,12 @@ class AWSAuthFetchSignInSessionOperationTests: BaseAuthorizationTests {
AmplifyCredentials.testDataWithExpiredTokens),
.notStarted)

let initAuth: MockIdentityProvider.MockInitiateAuthResponse = { _ in
let refreshTokenAuth: MockIdentityProvider.MockGetTokensFromRefreshTokenResponse = { _ in
throw AWSCognitoIdentityProvider.NotAuthorizedException(message: "NotAuthorized")
}

let plugin = configurePluginWith(
userPool: { MockIdentityProvider(mockInitiateAuthResponse: initAuth) },
userPool: { MockIdentityProvider(mockGetTokensFromRefreshTokenResponse: refreshTokenAuth) },
initialState: initialState)

let session = try await plugin.fetchAuthSession(options: AuthFetchSessionRequest.Options())
Expand Down Expand Up @@ -816,8 +816,8 @@ class AWSAuthFetchSignInSessionOperationTests: BaseAuthorizationTests {
AmplifyCredentials.testDataWithExpiredTokens),
.notStarted)

let initAuth: MockIdentityProvider.MockInitiateAuthResponse = { _ in
return InitiateAuthOutput(authenticationResult: .init(accessToken: "accessToken",
let refreshTokenAuth: MockIdentityProvider.MockGetTokensFromRefreshTokenResponse = { _ in
return GetTokensFromRefreshTokenOutput(authenticationResult: .init(accessToken: "accessToken",
expiresIn: 1000,
idToken: "idToken",
refreshToken: "refreshToke"))
Expand All @@ -827,7 +827,7 @@ class AWSAuthFetchSignInSessionOperationTests: BaseAuthorizationTests {
throw NSError(domain: NSURLErrorDomain, code: 1, userInfo: nil)
}
let plugin = configurePluginWith(
userPool: { MockIdentityProvider(mockInitiateAuthResponse: initAuth) },
userPool: { MockIdentityProvider(mockGetTokensFromRefreshTokenResponse: refreshTokenAuth) },
identityPool: { MockIdentity(mockGetCredentialsResponse: awsCredentials) },
initialState: initialState)

Expand Down
8 changes: 4 additions & 4 deletions Package.resolved

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading