@@ -15,6 +15,7 @@ import okhttp3.WebSocketListener
1515import okhttp3.internal.concurrent.TaskRunner
1616import okhttp3.internal.ws.RealWebSocket
1717import java.util.*
18+ import android.util.Log
1819import kotlin.coroutines.CoroutineContext
1920
2021class Realtime(client: Client) : Service(client), CoroutineScope {
@@ -31,17 +32,24 @@ class Realtime(client: Client) : Service(client), CoroutineScope {
3132 private const val DEBOUNCE_MILLIS = 1L
3233
3334 private var socket: RealWebSocket? = null
34- private var channelCallbacks = mutableMapOf <String , MutableCollection < RealtimeCallback > >()
35- private var errorCallbacks = mutableSetOf < ({{ spec . title | caseUcfirst }}Exception) -> Unit >()
35+ private var activeChannels = mutableSetOf <String >()
36+ private var activeSubscriptions = mutableMapOf< Int , RealtimeCallback >()
3637
3738 private var subCallDepth = 0
39+ private var reconnectAttempts = 0
40+ private var subscriptionsCounter = 0
41+ private var reconnect = true
3842 }
3943
4044 private fun createSocket() {
45+ if (activeChannels.isEmpty()) {
46+ return
47+ }
48+
4149 val queryParamBuilder = StringBuilder()
4250 .append("project=${client.config["project"]}")
4351
44- channelCallbacks.keys .forEach {
52+ activeChannels .forEach {
4553 queryParamBuilder
4654 .append("& channels[]=$it")
4755 }
@@ -51,6 +59,7 @@ class Realtime(client: Client) : Service(client), CoroutineScope {
5159 .build()
5260
5361 if (socket != null) {
62+ reconnect = false
5463 closeSocket()
5564 }
5665
@@ -71,6 +80,13 @@ class Realtime(client: Client) : Service(client), CoroutineScope {
7180 socket?.close(RealtimeCode.POLICY_VIOLATION.value, null)
7281 }
7382
83+ private fun getTimeout() = when {
84+ reconnectAttempts < 5 -> 1000L
85+ reconnectAttempts < 15 -> 5000L
86+ reconnectAttempts < 100 -> 10000L
87+ else -> 60000L
88+ }
89+
7490 fun subscribe(
7591 vararg channels: String,
7692 callback: (RealtimeResponseEvent<Any >) -> Unit,
@@ -85,20 +101,14 @@ class Realtime(client: Client) : Service(client), CoroutineScope {
85101 payloadType: Class<T >,
86102 callback: (RealtimeResponseEvent<T >) -> Unit,
87103 ): RealtimeSubscription {
88- channels.forEach {
89- if (!channelCallbacks.containsKey(it)) {
90- channelCallbacks[it] = mutableListOf(
91- RealtimeCallback(
92- payloadType,
93- callback as (RealtimeResponseEvent< *>) -> Unit
94- )
95- )
96- return@forEach
97- }
98- channelCallbacks[it]?.add(
99- RealtimeCallback(payloadType, callback as (RealtimeResponseEvent< *>) -> Unit)
100- )
101- }
104+ val counter = subscriptionsCounter++
105+
106+ activeChannels.addAll(channels)
107+ activeSubscriptions[counter] = RealtimeCallback(
108+ channels.toList(),
109+ payloadType,
110+ callback as (RealtimeResponseEvent< *>) -> Unit
111+ )
102112
103113 launch {
104114 subCallDepth++
@@ -109,25 +119,31 @@ class Realtime(client: Client) : Service(client), CoroutineScope {
109119 subCallDepth--
110120 }
111121
112- return RealtimeSubscription { unsubscribe(*channels) }
113- }
114-
115- fun unsubscribe(vararg channels: String) {
116- channels.forEach {
117- channelCallbacks[it] = mutableListOf()
118- }
119- if (channelCallbacks.all { it.value.isEmpty() }) {
120- errorCallbacks = mutableSetOf()
121- closeSocket()
122+ return RealtimeSubscription {
123+ activeSubscriptions.remove(counter)
124+ cleanUp(*channels)
125+ createSocket()
122126 }
123127 }
124128
125- fun doOnError(callback: ({{ spec .title | caseUcfirst }}Exception) -> Unit) {
126- errorCallbacks.add(callback)
129+ private fun cleanUp(vararg channels: String) {
130+ activeChannels.removeAll { channel ->
131+ if (!channels.contains(channel)) {
132+ return@removeAll false
133+ }
134+ activeSubscriptions.values.none { callback ->
135+ callback.channels.contains(channel)
136+ }
137+ }
127138 }
128139
129140 private inner class {{ spec .title | caseUcfirst }}WebSocketListener : WebSocketListener() {
130141
142+ override fun onOpen(webSocket: WebSocket, response: Response) {
143+ super.onOpen(webSocket, response)
144+ reconnectAttempts = 0
145+ }
146+
131147 override fun onMessage(webSocket: WebSocket, text: String) {
132148 super.onMessage(webSocket, text)
133149
@@ -141,27 +157,43 @@ class Realtime(client: Client) : Service(client), CoroutineScope {
141157 }
142158
143159 private fun handleResponseError(message: RealtimeResponse) {
144- val error = message.data.jsonCast< {{ spec .title | caseUcfirst }}Exception>()
145- errorCallbacks.forEach { it.invoke(error) }
160+ throw message.data.jsonCast< {{ spec .title | caseUcfirst }}Exception>()
146161 }
147162
148163 private suspend fun handleResponseEvent(message: RealtimeResponse) {
149164 val event = message.data.jsonCast<RealtimeResponseEvent <Any >>()
150- event.channels.forEachAsync { channel ->
151- channelCallbacks[channel]?.forEachAsync { callbackWrapper ->
152- event.payload = event.payload.jsonCast(callbackWrapper.payloadClass)
153- callbackWrapper.callback.invoke(event)
165+ if (event.channels.isEmpty()) {
166+ return
167+ }
168+ if (!event.channels.any { activeChannels.contains(it) }) {
169+ return
170+ }
171+ activeSubscriptions.values.forEachAsync { subscription ->
172+ if (event.channels.any { subscription.channels.contains(it) }) {
173+ event.payload = event.payload.jsonCast(subscription.payloadClass)
174+ subscription.callback(event)
154175 }
155176 }
156177 }
157178
158179 override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
159180 super.onClosing(webSocket, code, reason)
160- if (code == RealtimeCode.POLICY_VIOLATION.value) {
181+ if (!reconnect || code == RealtimeCode.POLICY_VIOLATION.value) {
182+ reconnect = true
161183 return
162184 }
185+
186+ val timeout = getTimeout()
187+
188+ Log.e(
189+ this@Realtime::class.java.name,
190+ "Realtime disconnected. Re-connecting in ${timeout / 1000} seconds.",
191+ {{ spec .title | caseUcfirst }}Exception(reason, code)
192+ )
193+
163194 launch {
164- delay(1000)
195+ delay(timeout)
196+ reconnectAttempts++
165197 createSocket()
166198 }
167199 }
0 commit comments