Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
// SPDX-License-Identifier: Apache-2.0
//

import Amplify
import AWSPluginsCore
import AWSCognitoIdentityProvider
import Foundation
import AWSPluginsCore
import Amplify
import ClientRuntime
import Foundation

struct RefreshUserPoolTokens: Action {

Expand All @@ -28,7 +28,6 @@ struct RefreshUserPoolTokens: Action {
return
}

let authEnv = try environment.authEnvironment()
let config = environment.userPoolConfiguration
let client = try? environment.cognitoUserPoolFactory()
let existingTokens = existingSignedIndata.cognitoUserPoolTokens
Expand All @@ -37,29 +36,35 @@ struct RefreshUserPoolTokens: Action {
for: existingSignedIndata.username,
with: environment)

let asfDeviceId = try await CognitoUserPoolASF.asfDeviceID(
for: existingSignedIndata.username,
credentialStoreClient: authEnv.credentialsClient)
let deviceKey: String? = {
if case .metadata(let data) = deviceMetadata {
return data.deviceKey
}
return nil
}()

let input = await InitiateAuthInput.refreshAuthInput(
username: existingSignedIndata.username,
refreshToken: existingTokens.refreshToken,
let input = GetTokensFromRefreshTokenInput(
clientId: config.clientId,
clientMetadata: [:],
asfDeviceId: asfDeviceId,
deviceMetadata: deviceMetadata,
environment: environment)
clientSecret: config.clientSecret,
deviceKey: deviceKey,
refreshToken: existingTokens.refreshToken
)

logVerbose("\(#fileID) Starting initiate auth refresh token", environment: environment)
logVerbose(
"\(#fileID) Starting get tokens from refresh token", environment: environment)

let response = try await client?.initiateAuth(input: input)
let response = try await client?.getTokensFromRefreshToken(input: input)

logVerbose("\(#fileID) Initiate auth response received", environment: environment)
logVerbose(
"\(#fileID) Get tokens from refresh token response received",
environment: environment)

guard let authenticationResult = response?.authenticationResult,
let idToken = authenticationResult.idToken,
let accessToken = authenticationResult.accessToken
let idToken = authenticationResult.idToken,
let accessToken = authenticationResult.accessToken,
let refreshToken = authenticationResult.refreshToken
else {

let event = RefreshSessionEvent(eventType: .throwError(.invalidTokens))
await dispatcher.send(event)
logVerbose("\(#fileID) Sending event \(event.type)", environment: environment)
Expand All @@ -69,9 +74,9 @@ struct RefreshUserPoolTokens: Action {
let userPoolTokens = AWSCognitoUserPoolTokens(
idToken: idToken,
accessToken: accessToken,
refreshToken: existingTokens.refreshToken,
expiresIn: authenticationResult.expiresIn
refreshToken: refreshToken
)

let signedInData = SignedInData(
signedInDate: existingSignedIndata.signedInDate,
signInMethod: existingSignedIndata.signInMethod,
Expand All @@ -96,13 +101,14 @@ struct RefreshUserPoolTokens: Action {
await dispatcher.send(event)
}

logVerbose("\(#fileID) Initiate auth complete", environment: environment)
logVerbose("\(#fileID) Get tokens from refresh token complete", environment: environment)
}
}

extension RefreshUserPoolTokens: DefaultLogger {
public static var log: Logger {
Amplify.Logging.logger(forCategory: CategoryType.auth.displayName, forNamespace: String(describing: self))
Amplify.Logging.logger(
forCategory: CategoryType.auth.displayName, forNamespace: String(describing: self))
}

public var log: Logger {
Expand All @@ -114,7 +120,7 @@ extension RefreshUserPoolTokens: CustomDebugDictionaryConvertible {
var debugDictionary: [String: Any] {
[
"identifier": identifier,
"existingSignedInData": existingSignedIndata
"existingSignedInData": existingSignedIndata,
]
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ protocol CognitoUserPoolBehavior {
/// Throws RevokeTokenOutputError
func revokeToken(input: RevokeTokenInput) async throws -> RevokeTokenOutput

/// Throws GetTokensFromRefreshTokenOutputError
func getTokensFromRefreshToken(input: GetTokensFromRefreshTokenInput) async throws -> GetTokensFromRefreshTokenOutput

// MARK: - User Attribute API's

/// Throws GetUserAttributeVerificationCodeOutputError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,27 +93,6 @@ extension InitiateAuthInput {
environment: environment)
}

static func refreshAuthInput(username: String,
refreshToken: String,
clientMetadata: [String: String],
asfDeviceId: String,
deviceMetadata: DeviceMetadata,
environment: UserPoolEnvironment) async -> InitiateAuthInput {

let authParameters = [
"REFRESH_TOKEN": refreshToken
]

return await buildInput(username: username,
authFlowType: .refreshTokenAuth,
authParameters: authParameters,
clientMetadata: clientMetadata,
asfDeviceId: asfDeviceId,
deviceMetadata: deviceMetadata,
environment: environment)

}

static func buildInput(username: String,
authFlowType: CognitoIdentityProviderClientTypes.AuthFlowType,
authParameters: [String: String],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
// SPDX-License-Identifier: Apache-2.0
//

import XCTest
import Amplify
import AWSPluginsCore
import AWSCognitoIdentityProvider
import AWSPluginsCore
import Amplify
import XCTest

@testable import AWSCognitoAuthPlugin

Expand All @@ -20,18 +20,19 @@ class RefreshUserPoolTokensTests: XCTestCase {

let action = RefreshUserPoolTokens(existingSignedIndata: .testData)

await action.execute(withDispatcher: MockDispatcher { event in
await action.execute(
withDispatcher: MockDispatcher { event in

guard let event = event as? RefreshSessionEvent else {
return
}
guard let event = event as? RefreshSessionEvent else {
return
}

if case let .throwError(error) = event.eventType {
XCTAssertNotNil(error)
XCTAssertEqual(error, .noUserPool)
expectation.fulfill()
}
}, environment: MockInvalidEnvironment()
if case let .throwError(error) = event.eventType {
XCTAssertNotNil(error)
XCTAssertEqual(error, .noUserPool)
expectation.fulfill()
}
}, environment: MockInvalidEnvironment()
)

await fulfillment(
Expand All @@ -45,25 +46,27 @@ class RefreshUserPoolTokensTests: XCTestCase {
let expectation = expectation(description: "refreshUserPoolTokens")
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
MockIdentityProvider(
mockInitiateAuthResponse: { _ in
return InitiateAuthOutput()
mockGetTokensFromRefreshTokenResponse: { _ in
return GetTokensFromRefreshTokenOutput()
}
)
}

let action = RefreshUserPoolTokens(existingSignedIndata: .testData)

await action.execute(withDispatcher: MockDispatcher { event in
await action.execute(
withDispatcher: MockDispatcher { event in

guard let event = event as? RefreshSessionEvent else { return }
guard let event = event as? RefreshSessionEvent else { return }

if case let .throwError(error) = event.eventType {
XCTAssertNotNil(error)
XCTAssertEqual(error, .invalidTokens)
expectation.fulfill()
}
}, environment: Defaults.makeDefaultAuthEnvironment(
userPoolFactory: identityProviderFactory)
if case let .throwError(error) = event.eventType {
XCTAssertNotNil(error)
XCTAssertEqual(error, .invalidTokens)
expectation.fulfill()
}
},
environment: Defaults.makeDefaultAuthEnvironment(
userPoolFactory: identityProviderFactory)
)

await fulfillment(
Expand All @@ -77,8 +80,8 @@ class RefreshUserPoolTokensTests: XCTestCase {
let expectation = expectation(description: "refreshUserPoolTokens")
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
MockIdentityProvider(
mockInitiateAuthResponse: { _ in
return InitiateAuthOutput(
mockGetTokensFromRefreshTokenResponse: { _ in
return GetTokensFromRefreshTokenOutput(
authenticationResult: .init(
accessToken: "accessTokenNew",
expiresIn: 100,
Expand All @@ -90,14 +93,17 @@ class RefreshUserPoolTokensTests: XCTestCase {

let action = RefreshUserPoolTokens(existingSignedIndata: .testData)

await action.execute(withDispatcher: MockDispatcher { event in
await action.execute(
withDispatcher: MockDispatcher { event in

if let userPoolEvent = event as? RefreshSessionEvent,
case .refreshIdentityInfo = userPoolEvent.eventType {
expectation.fulfill()
}
}, environment: Defaults.makeDefaultAuthEnvironment(
userPoolFactory: identityProviderFactory)
if let userPoolEvent = event as? RefreshSessionEvent,
case .refreshIdentityInfo = userPoolEvent.eventType
{
expectation.fulfill()
}
},
environment: Defaults.makeDefaultAuthEnvironment(
userPoolFactory: identityProviderFactory)
)

await fulfillment(
Expand All @@ -114,7 +120,7 @@ class RefreshUserPoolTokensTests: XCTestCase {

let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
MockIdentityProvider(
mockInitiateAuthResponse: { _ in
mockGetTokensFromRefreshTokenResponse: { _ in
throw testError
}
)
Expand All @@ -128,20 +134,97 @@ class RefreshUserPoolTokensTests: XCTestCase {

let action = RefreshUserPoolTokens(existingSignedIndata: .testData)

await action.execute(withDispatcher: MockDispatcher { event in
await action.execute(
withDispatcher: MockDispatcher { event in

if let userPoolEvent = event as? RefreshSessionEvent,
case let .throwError(error) = userPoolEvent.eventType {
XCTAssertNotNil(error)
XCTAssertEqual(error, .service(testError))
expectation.fulfill()
}
}, environment: environment)
if let userPoolEvent = event as? RefreshSessionEvent,
case let .throwError(error) = userPoolEvent.eventType
{
XCTAssertNotNil(error)
XCTAssertEqual(error, .service(testError))
expectation.fulfill()
}
}, environment: environment)

await fulfillment(
of: [expectation],
timeout: 0.1
)
}

func testRefreshTokenRotation() async {

let expectation = expectation(description: "refreshTokenRotation")
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
MockIdentityProvider(
mockGetTokensFromRefreshTokenResponse: { _ in
return GetTokensFromRefreshTokenOutput(
authenticationResult: .init(
accessToken: "accessTokenNew",
expiresIn: 100,
idToken: "idTokenNew",
refreshToken: "refreshTokenRotated"))
}
)
}

let action = RefreshUserPoolTokens(existingSignedIndata: .testData)

await action.execute(
withDispatcher: MockDispatcher { event in

if let userPoolEvent = event as? RefreshSessionEvent,
case let .refreshIdentityInfo(signedInData, _) = userPoolEvent.eventType
{
XCTAssertEqual(
signedInData.cognitoUserPoolTokens.refreshToken, "refreshTokenRotated")
expectation.fulfill()
}
},
environment: Defaults.makeDefaultAuthEnvironment(
userPoolFactory: identityProviderFactory)
)

await fulfillment(
of: [expectation],
timeout: 0.1
)
}
func testRefreshTokenMissing() async {

let expectation = expectation(description: "refreshTokenMissing")
let identityProviderFactory: BasicSRPAuthEnvironment.CognitoUserPoolFactory = {
MockIdentityProvider(
mockGetTokensFromRefreshTokenResponse: { _ in
return GetTokensFromRefreshTokenOutput(
authenticationResult: .init(
accessToken: "accessTokenNew",
expiresIn: 100,
idToken: "idTokenNew",
refreshToken: nil))
}
)
}

let action = RefreshUserPoolTokens(existingSignedIndata: .testData)

await action.execute(
withDispatcher: MockDispatcher { event in

if let userPoolEvent = event as? RefreshSessionEvent,
case let .throwError(error) = userPoolEvent.eventType
{
XCTAssertEqual(error, .invalidTokens)
expectation.fulfill()
}
},
environment: Defaults.makeDefaultAuthEnvironment(
userPoolFactory: identityProviderFactory)
)

await fulfillment(
of: [expectation],
timeout: 0.1
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ class AuthHubEventHandlerTests: XCTestCase {
.notStarted)

let mockIdentityProvider = MockIdentityProvider(
mockInitiateAuthResponse: { _ in
mockGetTokensFromRefreshTokenResponse: { _ in
throw AWSCognitoIdentityProvider.NotAuthorizedException()
})

Expand Down
Loading
Loading