diff --git a/stream-android-core/src/main/java/io/getstream/android/core/internal/socket/StreamSocketSession.kt b/stream-android-core/src/main/java/io/getstream/android/core/internal/socket/StreamSocketSession.kt index 11c1f78..3811c5a 100644 --- a/stream-android-core/src/main/java/io/getstream/android/core/internal/socket/StreamSocketSession.kt +++ b/stream-android-core/src/main/java/io/getstream/android/core/internal/socket/StreamSocketSession.kt @@ -368,6 +368,16 @@ internal class StreamSocketSession( failure(it) } } + + override fun onFailure(t: Throwable, response: Response?) { + logger.e(t) { "[onFailure] Socket failure. ${t.message}" } + failure(t) + } + + override fun onClosed(code: Int, reason: String) { + logger.e { "[onClosed] Socket closed. Code: $code, Reason: $reason" } + failure(IOException("Socket closed. Code: $code, Reason: $reason")) + } } val hsRes = diff --git a/stream-android-core/src/test/java/io/getstream/android/core/internal/socket/StreamSocketSessionTest.kt b/stream-android-core/src/test/java/io/getstream/android/core/internal/socket/StreamSocketSessionTest.kt index d9cc7cc..146c6fe 100644 --- a/stream-android-core/src/test/java/io/getstream/android/core/internal/socket/StreamSocketSessionTest.kt +++ b/stream-android-core/src/test/java/io/getstream/android/core/internal/socket/StreamSocketSessionTest.kt @@ -36,8 +36,16 @@ import io.getstream.android.core.internal.model.events.StreamHealthCheckEvent import io.getstream.android.core.internal.serialization.StreamCompositeEventSerializationImpl import io.getstream.android.core.internal.serialization.StreamCompositeSerializationEvent import io.getstream.android.core.internal.socket.model.ConnectUserData -import io.mockk.* +import io.mockk.MockKAnnotations +import io.mockk.Runs +import io.mockk.every +import io.mockk.just +import io.mockk.mockk +import io.mockk.slot +import io.mockk.verify +import java.io.IOException import junit.framework.Assert.assertEquals +import kotlin.time.Duration.Companion.seconds import kotlinx.coroutines.async import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.test.advanceUntilIdle @@ -1208,6 +1216,88 @@ class StreamSocketSessionTest { job.cancel() } + @Test + fun `handshake onFailure triggers connect failure`() = + runTest(timeout = 1.seconds) { + testHandshakeFailure { hsListener, _ -> + val socketFailure = RuntimeException("WebSocket connection failed") + val mockResponse = mockk(relaxed = true) + hsListener.onFailure(socketFailure, mockResponse) + } + } + + @Test + fun `handshake onClosed triggers connect failure with IOException`() = + runTest(timeout = 1.seconds) { + val result = testHandshakeFailure { hsListener, _ -> + hsListener.onClosed(12345, "Closed because yes") + } + + // Verify that the failure contains an IOException with the expected message + val exception = result.exceptionOrNull() + assertTrue( + "Expected IOException but got ${exception?.javaClass?.simpleName}", + exception is IOException, + ) + } + + private suspend fun testHandshakeFailure( + triggerFailure: (StreamWebSocketListener, StreamSubscription) -> Unit + ): Result { + val seenStates = mutableListOf() + every { subs.forEach(any()) } answers + { + val consumer = arg<(StreamClientListener) -> Unit>(0) + val listener = + object : StreamClientListener { + override fun onState(state: StreamConnectionState) { + seenStates += state + } + + override fun onEvent(event: Any) {} + } + consumer(listener) + Result.success(Unit) + } + + val lifeSub = mockk(relaxed = true) + val hsSub = mockk(relaxed = true) + + every { socket.subscribe(any()) } returns Result.success(lifeSub) + + var hsListener: StreamWebSocketListener? = null + every { socket.subscribe(any(), any()) } answers + { + hsListener = firstArg() + Result.success(hsSub) + } + + every { socket.open(config) } answers + { + val listener = hsListener ?: error("Handshake listener not installed") + triggerFailure(listener, hsSub) + Result.success(Unit) + } + + every { socket.close() } returns Result.success(Unit) + + val result = session.connect(connectUserData()) + + assertTrue(result.isFailure) + + // Verify that the handshake subscription was cancelled + verify { hsSub.cancel() } + + // Verify that proper connection states were emitted + assertTrue(seenStates.first() is StreamConnectionState.Connecting.Opening) + assertTrue(seenStates.any { it is StreamConnectionState.Disconnected }) + + // Verify no health monitoring started + verify(exactly = 0) { health.start() } + + return result + } + private fun connectUserData(): ConnectUserData = ConnectUserData("u", "t", null, null, false, null, emptyMap()) }