Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import com.amplifyframework.auth.result.step.AuthNextSignUpStep
import com.amplifyframework.auth.result.step.AuthSignUpStep
import com.amplifyframework.statemachine.Action
import com.amplifyframework.statemachine.codegen.actions.SignUpActions
import com.amplifyframework.statemachine.codegen.data.SignUpData
import com.amplifyframework.statemachine.codegen.events.SignUpEvent

internal object SignUpCognitoActions : SignUpActions {
Expand Down Expand Up @@ -69,12 +68,9 @@ internal object SignUpCognitoActions : SignUpActions {
}

val codeDeliveryDetails = response?.codeDeliveryDetails.toAuthCodeDeliveryDetails()
val signUpData = SignUpData(
username,
event.signUpData.validationData,
event.signUpData.clientMetadata,
response?.session,
response?.userSub
val signUpData = event.signUpData.copy(
session = response?.session,
userId = response?.userSub
)
if (response?.userConfirmed == true) {
var signUpStep = AuthSignUpStep.DONE
Expand Down Expand Up @@ -106,7 +102,7 @@ internal object SignUpCognitoActions : SignUpActions {
SignUpEvent(SignUpEvent.EventType.InitiateSignUpComplete(signUpData, signUpResult))
}
} catch (e: Exception) {
SignUpEvent(SignUpEvent.EventType.ThrowError(e))
SignUpEvent(SignUpEvent.EventType.ThrowError(event.signUpData, e))
}
logger.verbose("$id Sending event ${evt.type}")
dispatcher.send(evt)
Expand Down Expand Up @@ -136,13 +132,7 @@ internal object SignUpCognitoActions : SignUpActions {
this.clientMetadata = event.signUpData.clientMetadata
this.session = event.signUpData.session
}
val signUpData = SignUpData(
username,
event.signUpData.validationData,
event.signUpData.clientMetadata,
response?.session,
event.signUpData.userId
)
val signUpData = event.signUpData.copy(session = response?.session)
var signUpStep = AuthSignUpStep.DONE
if (response?.session != null) {
signUpStep = AuthSignUpStep.COMPLETE_AUTO_SIGN_IN
Expand All @@ -159,7 +149,7 @@ internal object SignUpCognitoActions : SignUpActions {
)
SignUpEvent(SignUpEvent.EventType.SignedUp(signUpData, signUpResult))
} catch (e: Exception) {
SignUpEvent(SignUpEvent.EventType.ThrowError(e))
SignUpEvent(SignUpEvent.EventType.ThrowError(event.signUpData, e))
}
logger.verbose("$id Sending event ${evt.type}")
dispatcher.send(evt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.amplifyframework.auth.result.AuthSignUpResult
import com.amplifyframework.statemachine.codegen.data.SignUpData
import com.amplifyframework.statemachine.codegen.events.SignUpEvent
import com.amplifyframework.statemachine.codegen.states.SignUpState
import com.amplifyframework.statemachine.codegen.states.getSignUpData
import kotlinx.coroutines.flow.drop
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.onSubscription
Expand All @@ -37,21 +38,25 @@ internal class ConfirmSignUpUseCase(private val stateMachine: AuthStateMachine)
): AuthSignUpResult {
stateMachine.throwIfNotConfigured()

val startingState = stateMachine.getCurrentState().authSignUpState
val currentState = stateMachine.getCurrentState()
val existingSignUpData = currentState.authSignUpState?.getSignUpData()

val clientMetadata = (options as? AWSCognitoAuthConfirmSignUpOptions)?.clientMetadata
val signUpData = if (existingSignUpData?.username == username) {
existingSignUpData.copy(clientMetadata = clientMetadata)
} else {
SignUpData(
username = username,
validationData = null,
clientMetadata = clientMetadata,
session = null,
userId = null
)
}

val result = stateMachine.state
.onSubscription {
var userId: String? = null
var session: String? = null
if (startingState is SignUpState.AwaitingUserConfirmation &&
startingState.signUpData.username == username
) {
session = startingState.signUpData.session
userId = startingState.signUpResult.userId
}
val clientMetadata = (options as? AWSCognitoAuthConfirmSignUpOptions)?.clientMetadata
val signupData = SignUpData(username, null, clientMetadata, session, userId)
val event = SignUpEvent(SignUpEvent.EventType.ConfirmSignUp(signupData, confirmationCode))
val event = SignUpEvent(SignUpEvent.EventType.ConfirmSignUp(signUpData, confirmationCode))
stateMachine.send(event)
}
.drop(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ internal class SignUpEvent(

data class SignedUp(val signUpData: SignUpData, val signUpResult: AuthSignUpResult) : EventType()

data class ThrowError(val exception: Exception) : EventType()
data class ThrowError(val signUpData: SignUpData, val exception: Exception) : EventType()
}

override val type: String = eventType.javaClass.simpleName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ internal sealed class SignUpState : State {
data class AwaitingUserConfirmation(val signUpData: SignUpData, val signUpResult: AuthSignUpResult) : SignUpState()
data class ConfirmingSignUp(val signUpData: SignUpData) : SignUpState()
data class SignedUp(val signUpData: SignUpData, val signUpResult: AuthSignUpResult) : SignUpState()
data class Error(val exception: Exception, var hasNewResponse: Boolean = true) : SignUpState()
data class Error(
val signUpData: SignUpData,
val exception: Exception,
var hasNewResponse: Boolean = true
) : SignUpState()

class Resolver(private val signUpActions: SignUpActions) :
StateMachineResolver<SignUpState> {
Expand All @@ -56,7 +60,7 @@ internal sealed class SignUpState : State {
)
}
is SignUpEvent.EventType.ThrowError -> {
StateResolution(Error(signUpEvent.exception))
StateResolution(Error(signUpEvent.signUpData, signUpEvent.exception))
}
else -> defaultResolution
}
Expand All @@ -82,7 +86,7 @@ internal sealed class SignUpState : State {
)
}
is SignUpEvent.EventType.ThrowError -> {
StateResolution(Error(signUpEvent.exception))
StateResolution(Error(signUpEvent.signUpData, signUpEvent.exception))
}
else -> defaultResolution
}
Expand All @@ -100,7 +104,7 @@ internal sealed class SignUpState : State {
)
}
is SignUpEvent.EventType.ThrowError -> {
StateResolution(Error(signUpEvent.exception))
StateResolution(Error(signUpEvent.signUpData, signUpEvent.exception))
}
else -> defaultResolution
}
Expand All @@ -121,7 +125,7 @@ internal sealed class SignUpState : State {
StateResolution(SignedUp(signUpEvent.signUpData, signUpEvent.signUpResult))
}
is SignUpEvent.EventType.ThrowError -> {
StateResolution(Error(signUpEvent.exception))
StateResolution(Error(signUpEvent.signUpData, signUpEvent.exception))
}
else -> defaultResolution
}
Expand All @@ -144,3 +148,12 @@ internal sealed class SignUpState : State {
}
}
}

internal fun SignUpState.getSignUpData(): SignUpData? = when (this) {
is SignUpState.AwaitingUserConfirmation -> this.signUpData
is SignUpState.ConfirmingSignUp -> this.signUpData
is SignUpState.Error -> this.signUpData
is SignUpState.InitiatingSignUp -> this.signUpData
is SignUpState.NotStarted -> null
is SignUpState.SignedUp -> this.signUpData
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,24 @@ import com.amplifyframework.auth.cognito.testUtil.withSignUpEvent
import com.amplifyframework.auth.result.AuthSignUpResult
import com.amplifyframework.auth.result.step.AuthNextSignUpStep
import com.amplifyframework.auth.result.step.AuthSignUpStep
import com.amplifyframework.statemachine.codegen.data.SignUpData
import com.amplifyframework.statemachine.codegen.events.SignUpEvent
import com.amplifyframework.statemachine.codegen.states.AuthState
import com.amplifyframework.statemachine.codegen.states.AuthenticationState
import com.amplifyframework.statemachine.codegen.states.SignUpState
import io.kotest.assertions.throwables.shouldThrowAny
import io.kotest.matchers.nulls.shouldBeNull
import io.kotest.matchers.shouldBe
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
import io.mockk.justRun
import io.mockk.mockk
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.runCurrent
import kotlinx.coroutines.test.runTest
import org.junit.Test
Expand All @@ -50,6 +54,14 @@ class ConfirmSignUpUseCaseTest {
}
private val useCase = ConfirmSignUpUseCase(stateMachine = stateMachine)

private val signUpData = SignUpData(
username = "username",
validationData = mapOf("key" to "value"),
clientMetadata = mapOf("meta" to "data"),
session = "session",
userId = "userId"
)

@Test
fun `fails if not configured`() = runTest {
val expectedAuthError = InvalidUserPoolConfigurationException()
Expand All @@ -75,30 +87,96 @@ class ConfirmSignUpUseCaseTest {

runCurrent()
stateFlow.emit(mockAuthState(SignUpState.ConfirmingSignUp(mockk())))
stateFlow.emit(mockAuthState(SignUpState.Error(exception)))
stateFlow.emit(mockAuthState(SignUpState.Error(mockk(), exception)))
}

@Test
fun `sends expected event`() = runTest {
coEvery { stateMachine.getCurrentState().authNState } returns AuthenticationState.Configured()
coEvery { stateMachine.getCurrentState().authSignUpState } returns null

launch {
useCase.execute("user", "pass")
executeUseCaseToCompletion()

coVerify {
stateMachine.send(
withSignUpEvent<SignUpEvent.EventType.ConfirmSignUp> { event ->
event.signUpData.username shouldBe "user"
event.confirmationCode shouldBe "pass"
}
)
}
}

coVerify {
stateMachine.send(
withSignUpEvent<SignUpEvent.EventType.ConfirmSignUp> { event ->
event.signUpData.username shouldBe "user"
event.confirmationCode shouldBe "pass"
@Test
fun `uses signUpData values from existing error state`() = runTest {
coEvery { stateMachine.getCurrentState().authNState } returns AuthenticationState.Configured()
coEvery { stateMachine.getCurrentState().authSignUpState } returns
SignUpState.Error(signUpData = signUpData, exception = mockk())

executeUseCaseToCompletion(username = signUpData.username)

coVerify {
stateMachine.send(
withSignUpEvent<SignUpEvent.EventType.ConfirmSignUp> { event ->
event.signUpData.run {
username shouldBe signUpData.username
validationData shouldBe signUpData.validationData
clientMetadata.shouldBeNull() // was not passed in options
session shouldBe signUpData.session
userId shouldBe signUpData.userId
}
)
}
event.confirmationCode shouldBe "pass"
}
)
}
}

runCurrent()
stateFlow.emit(mockAuthState(SignUpState.ConfirmingSignUp(mockk())))
stateFlow.emit(mockAuthState(SignUpState.SignedUp(mockk(), mockk())))
@Test
fun `uses session value from existing AwaitingConfirmation state`() = runTest {
coEvery { stateMachine.getCurrentState().authNState } returns AuthenticationState.Configured()
coEvery { stateMachine.getCurrentState().authSignUpState } returns
SignUpState.AwaitingUserConfirmation(signUpData = signUpData, signUpResult = mockk())

executeUseCaseToCompletion(username = signUpData.username)

coVerify {
stateMachine.send(
withSignUpEvent<SignUpEvent.EventType.ConfirmSignUp> { event ->
event.signUpData.run {
username shouldBe signUpData.username
validationData shouldBe signUpData.validationData
clientMetadata.shouldBeNull() // was not passed in options
session shouldBe signUpData.session
userId shouldBe signUpData.userId
}
event.confirmationCode shouldBe "pass"
}
)
}
}

@Test
fun `does not use values from existing state if username does not match`() = runTest {
coEvery { stateMachine.getCurrentState().authNState } returns AuthenticationState.Configured()
coEvery { stateMachine.getCurrentState().authSignUpState } returns
SignUpState.AwaitingUserConfirmation(signUpData = signUpData, signUpResult = mockk())

executeUseCaseToCompletion(username = "bob")

coVerify {
stateMachine.send(
withSignUpEvent<SignUpEvent.EventType.ConfirmSignUp> { event ->
event.signUpData.run {
username shouldBe "bob"
validationData.shouldBeNull()
clientMetadata.shouldBeNull() // was not passed in options
session.shouldBeNull()
userId.shouldBeNull()
}
event.confirmationCode shouldBe "pass"
}
)
}
}

@Test
Expand All @@ -116,17 +194,24 @@ class ConfirmSignUpUseCaseTest {
coEvery { stateMachine.getCurrentState().authNState } returns AuthenticationState.Configured()
coEvery { stateMachine.getCurrentState().authSignUpState } returns null

launch {
val result = useCase.execute("user", "pass")
result shouldBe expectedResult
}

runCurrent()
stateFlow.emit(mockAuthState(SignUpState.ConfirmingSignUp(mockk())))
stateFlow.emit(mockAuthState(SignUpState.SignedUp(mockk(), expectedResult)))
val result = executeUseCaseToCompletion(signUpResult = expectedResult)
result shouldBe expectedResult
}

private fun mockAuthState(signUpState: SignUpState): AuthState = mockk {
coEvery { authSignUpState } returns signUpState
}

@Suppress("SuspendFunctionOnCoroutineScope")
private suspend fun TestScope.executeUseCaseToCompletion(
username: String = "user",
confirmationCode: String = "pass",
signUpResult: AuthSignUpResult = mockk()
): AuthSignUpResult {
val result = async { useCase.execute(username, confirmationCode) }
runCurrent()
stateFlow.emit(mockAuthState(SignUpState.ConfirmingSignUp(mockk())))
stateFlow.emit(mockAuthState(SignUpState.SignedUp(mockk(), signUpResult)))
return result.await()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class SignUpUseCaseTest {
}

runCurrent()
stateFlow.emit(mockAuthState(SignUpState.Error(exception)))
stateFlow.emit(mockAuthState(SignUpState.Error(mockk(), exception)))
}

@Test
Expand Down