@@ -36,8 +36,16 @@ import io.getstream.android.core.internal.model.events.StreamHealthCheckEvent
3636import io.getstream.android.core.internal.serialization.StreamCompositeEventSerializationImpl
3737import io.getstream.android.core.internal.serialization.StreamCompositeSerializationEvent
3838import io.getstream.android.core.internal.socket.model.ConnectUserData
39- import io.mockk.*
39+ import io.mockk.MockKAnnotations
40+ import io.mockk.Runs
41+ import io.mockk.every
42+ import io.mockk.just
43+ import io.mockk.mockk
44+ import io.mockk.slot
45+ import io.mockk.verify
46+ import java.io.IOException
4047import junit.framework.Assert.assertEquals
48+ import kotlin.time.Duration.Companion.seconds
4149import kotlinx.coroutines.async
4250import kotlinx.coroutines.cancelAndJoin
4351import kotlinx.coroutines.test.advanceUntilIdle
@@ -1208,6 +1216,88 @@ class StreamSocketSessionTest {
12081216 job.cancel()
12091217 }
12101218
1219+ @Test
1220+ fun `handshake onFailure triggers connect failure` () =
1221+ runTest(timeout = 1 .seconds) {
1222+ testHandshakeFailure { hsListener, _ ->
1223+ val socketFailure = RuntimeException (" WebSocket connection failed" )
1224+ val mockResponse = mockk<Response >(relaxed = true )
1225+ hsListener.onFailure(socketFailure, mockResponse)
1226+ }
1227+ }
1228+
1229+ @Test
1230+ fun `handshake onClosed triggers connect failure with IOException` () =
1231+ runTest(timeout = 1 .seconds) {
1232+ val result = testHandshakeFailure { hsListener, _ ->
1233+ hsListener.onClosed(12345 , " Closed because yes" )
1234+ }
1235+
1236+ // Verify that the failure contains an IOException with the expected message
1237+ val exception = result.exceptionOrNull()
1238+ assertTrue(
1239+ " Expected IOException but got ${exception?.javaClass?.simpleName} " ,
1240+ exception is IOException ,
1241+ )
1242+ }
1243+
1244+ private suspend fun testHandshakeFailure (
1245+ triggerFailure : (StreamWebSocketListener , StreamSubscription ) -> Unit
1246+ ): Result <StreamConnectionState .Connected > {
1247+ val seenStates = mutableListOf<StreamConnectionState >()
1248+ every { subs.forEach(any()) } answers
1249+ {
1250+ val consumer = arg< (StreamClientListener ) -> Unit > (0 )
1251+ val listener =
1252+ object : StreamClientListener {
1253+ override fun onState (state : StreamConnectionState ) {
1254+ seenStates + = state
1255+ }
1256+
1257+ override fun onEvent (event : Any ) {}
1258+ }
1259+ consumer(listener)
1260+ Result .success(Unit )
1261+ }
1262+
1263+ val lifeSub = mockk<StreamSubscription >(relaxed = true )
1264+ val hsSub = mockk<StreamSubscription >(relaxed = true )
1265+
1266+ every { socket.subscribe(any<StreamWebSocketListener >()) } returns Result .success(lifeSub)
1267+
1268+ var hsListener: StreamWebSocketListener ? = null
1269+ every { socket.subscribe(any<StreamWebSocketListener >(), any()) } answers
1270+ {
1271+ hsListener = firstArg()
1272+ Result .success(hsSub)
1273+ }
1274+
1275+ every { socket.open(config) } answers
1276+ {
1277+ val listener = hsListener ? : error(" Handshake listener not installed" )
1278+ triggerFailure(listener, hsSub)
1279+ Result .success(Unit )
1280+ }
1281+
1282+ every { socket.close() } returns Result .success(Unit )
1283+
1284+ val result = session.connect(connectUserData())
1285+
1286+ assertTrue(result.isFailure)
1287+
1288+ // Verify that the handshake subscription was cancelled
1289+ verify { hsSub.cancel() }
1290+
1291+ // Verify that proper connection states were emitted
1292+ assertTrue(seenStates.first() is StreamConnectionState .Connecting .Opening )
1293+ assertTrue(seenStates.any { it is StreamConnectionState .Disconnected })
1294+
1295+ // Verify no health monitoring started
1296+ verify(exactly = 0 ) { health.start() }
1297+
1298+ return result
1299+ }
1300+
12111301 private fun connectUserData (): ConnectUserData =
12121302 ConnectUserData (" u" , " t" , null , null , false , null , emptyMap())
12131303}
0 commit comments