diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a393e0a..ca6dfc14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## 1.0.0-BETA29 (unreleased) + +* Fix potential race condition between jobs in `connect()` and `disconnect()`. + ## 1.0.0-BETA28 * Update PowerSync SQLite core extension to 0.3.12. diff --git a/core/src/commonIntegrationTest/kotlin/com/powersync/SyncIntegrationTest.kt b/core/src/commonIntegrationTest/kotlin/com/powersync/SyncIntegrationTest.kt index 2bf2e4bb..cfefbf7b 100644 --- a/core/src/commonIntegrationTest/kotlin/com/powersync/SyncIntegrationTest.kt +++ b/core/src/commonIntegrationTest/kotlin/com/powersync/SyncIntegrationTest.kt @@ -27,9 +27,9 @@ import com.powersync.utils.JsonUtil import dev.mokkery.answering.returns import dev.mokkery.everySuspend import dev.mokkery.mock -import kotlinx.coroutines.CoroutineScope +import dev.mokkery.verify +import kotlinx.coroutines.DelicateCoroutinesApi import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.flow.receiveAsFlow import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest import kotlinx.serialization.encodeToString @@ -38,6 +38,7 @@ import kotlin.test.AfterTest import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import kotlin.test.assertFalse import kotlin.test.assertNotNull import kotlin.test.assertTrue @@ -99,8 +100,8 @@ class SyncIntegrationTest { dbFilename = "testdb", ) as PowerSyncDatabaseImpl - private fun CoroutineScope.syncStream(): SyncStream { - val client = MockSyncService.client(this, syncLines.receiveAsFlow()) + private fun syncStream(): SyncStream { + val client = MockSyncService(syncLines) return SyncStream( bucketStorage = database.bucketStorage, connector = connector, @@ -117,6 +118,68 @@ class SyncIntegrationTest { assertEquals(amount, users.size, "Expected $amount users, got $users") } + @Test + @OptIn(DelicateCoroutinesApi::class) + fun closesResponseStreamOnDatabaseClose() = + runTest { + val syncStream = syncStream() + database.connectInternal(syncStream, 1000L) + + turbineScope(timeout = 10.0.seconds) { + val turbine = database.currentStatus.asFlow().testIn(this) + turbine.waitFor { it.connected } + + database.close() + turbine.waitFor { !it.connected } + turbine.cancel() + } + + // Closing the database should have closed the channel + assertTrue { syncLines.isClosedForSend } + } + + @Test + @OptIn(DelicateCoroutinesApi::class) + fun cleansResourcesOnDisconnect() = + runTest { + val syncStream = syncStream() + database.connectInternal(syncStream, 1000L) + + turbineScope(timeout = 10.0.seconds) { + val turbine = database.currentStatus.asFlow().testIn(this) + turbine.waitFor { it.connected } + + database.disconnect() + turbine.waitFor { !it.connected } + turbine.cancel() + } + + // Disconnecting should have closed the channel + assertTrue { syncLines.isClosedForSend } + + // And called invalidateCredentials on the connector + verify { connector.invalidateCredentials() } + } + + @Test + fun cannotUpdateSchemaWhileConnected() = + runTest { + val syncStream = syncStream() + database.connectInternal(syncStream, 1000L) + + turbineScope(timeout = 10.0.seconds) { + val turbine = database.currentStatus.asFlow().testIn(this) + turbine.waitFor { it.connected } + turbine.cancel() + } + + assertFailsWith("Cannot update schema while connected") { + database.updateSchema(Schema()) + } + + database.close() + } + @Test fun testPartialSync() = runTest { diff --git a/core/src/commonMain/kotlin/com/powersync/db/PowerSyncDatabaseImpl.kt b/core/src/commonMain/kotlin/com/powersync/db/PowerSyncDatabaseImpl.kt index ccde3ae7..244d9516 100644 --- a/core/src/commonMain/kotlin/com/powersync/db/PowerSyncDatabaseImpl.kt +++ b/core/src/commonMain/kotlin/com/powersync/db/PowerSyncDatabaseImpl.kt @@ -24,10 +24,11 @@ import com.powersync.utils.JsonParam import com.powersync.utils.JsonUtil import com.powersync.utils.throttle import com.powersync.utils.toJsonObject +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.Job -import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.ensureActive import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.filter @@ -95,9 +96,7 @@ internal class PowerSyncDatabaseImpl( override val currentStatus: SyncStatus = SyncStatus() private val mutex = Mutex() - private var syncStream: SyncStream? = null - private var syncJob: Job? = null - private var uploadJob: Job? = null + private var syncSupervisorJob: Job? = null // This is set in the init private lateinit var powerSyncVersion: String @@ -123,7 +122,7 @@ internal class PowerSyncDatabaseImpl( override suspend fun updateSchema(schema: Schema) = runWrappedSuspending { mutex.withLock { - if (this.syncStream != null) { + if (this.syncSupervisorJob != null) { throw PowerSyncException( "Cannot update schema while connected", cause = Exception("PowerSync client is already connected"), @@ -161,12 +160,11 @@ internal class PowerSyncDatabaseImpl( stream: SyncStream, crudThrottleMs: Long, ) { - this.syncStream = stream - val db = this - - syncJob = - scope.launch { + val job = SupervisorJob(scope.coroutineContext[Job]) + syncSupervisorJob = job + scope.launch(job) { + launch { // Get a global lock for checking mutex maps val streamMutex = resource.group.syncMutex @@ -181,7 +179,7 @@ internal class PowerSyncDatabaseImpl( // (The tryLock should throw if this client already holds the lock). logger.w(streamConflictMessage) } - } catch (ex: IllegalStateException) { + } catch (_: IllegalStateException) { logger.e { "The streaming sync client did not disconnect before connecting" } } @@ -194,40 +192,46 @@ internal class PowerSyncDatabaseImpl( // We have a lock if we reached here try { ensureActive() - syncStream!!.streamingSync() + stream.streamingSync() } finally { streamMutex.unlock(db) } } - scope.launch { - syncStream!!.status.asFlow().collect { - currentStatus.update( - connected = it.connected, - connecting = it.connecting, - uploading = it.uploading, - downloading = it.downloading, - lastSyncedAt = it.lastSyncedAt, - hasSynced = it.hasSynced, - uploadError = it.uploadError, - downloadError = it.downloadError, - clearDownloadError = it.downloadError == null, - clearUploadError = it.uploadError == null, - priorityStatusEntries = it.priorityStatusEntries, - ) + launch { + stream.status.asFlow().collect { + currentStatus.update( + connected = it.connected, + connecting = it.connecting, + uploading = it.uploading, + downloading = it.downloading, + lastSyncedAt = it.lastSyncedAt, + hasSynced = it.hasSynced, + uploadError = it.uploadError, + downloadError = it.downloadError, + clearDownloadError = it.downloadError == null, + clearUploadError = it.uploadError == null, + priorityStatusEntries = it.priorityStatusEntries, + ) + } } - } - uploadJob = - scope.launch { + launch { internalDb .updatesOnTables() .filter { it.contains(InternalTable.CRUD.toString()) } .throttle(crudThrottleMs) .collect { - syncStream!!.triggerCrudUpload() + stream.triggerCrudUpload() } } + } + + job.invokeOnCompletion { + if (it is DisconnectRequestedException) { + stream.invalidateCredentials() + } + } } override suspend fun getCrudBatch(limit: Int): CrudBatch? { @@ -364,17 +368,12 @@ internal class PowerSyncDatabaseImpl( override suspend fun disconnect() = mutex.withLock { disconnectInternal() } private suspend fun disconnectInternal() { - if (syncJob != null && syncJob!!.isActive) { - syncJob?.cancelAndJoin() - } - - if (uploadJob != null && uploadJob!!.isActive) { - uploadJob?.cancelAndJoin() - } - - if (syncStream != null) { - syncStream?.invalidateCredentials() - syncStream = null + val syncJob = syncSupervisorJob + if (syncJob != null && syncJob.isActive) { + // Using this exception type will also make the sync job invalidate credentials. + syncJob.cancel(DisconnectRequestedException) + syncJob.join() + syncSupervisorJob = null } currentStatus.update( @@ -470,7 +469,7 @@ internal class PowerSyncDatabaseImpl( /** * Check that a supported version of the powersync extension is loaded. */ - private suspend fun checkVersion(powerSyncVersion: String) { + private fun checkVersion(powerSyncVersion: String) { // Parse version val versionInts: List = try { @@ -488,3 +487,5 @@ internal class PowerSyncDatabaseImpl( } } } + +internal object DisconnectRequestedException : CancellationException("disconnect() called") diff --git a/core/src/commonTest/kotlin/com/powersync/sync/SyncStreamTest.kt b/core/src/commonTest/kotlin/com/powersync/sync/SyncStreamTest.kt index 30588d47..41573e3f 100644 --- a/core/src/commonTest/kotlin/com/powersync/sync/SyncStreamTest.kt +++ b/core/src/commonTest/kotlin/com/powersync/sync/SyncStreamTest.kt @@ -30,7 +30,6 @@ import dev.mokkery.verifySuspend import io.ktor.client.engine.mock.MockEngine import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.delay -import kotlinx.coroutines.flow.receiveAsFlow import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest import kotlinx.coroutines.withTimeout @@ -210,7 +209,7 @@ class SyncStreamTest { // TODO: It would be neat if we could use in-memory sqlite instances instead of mocking everything // Revisit https://github.com/powersync-ja/powersync-kotlin/pull/117/files at some point val syncLines = Channel() - val client = MockSyncService.client(this, syncLines.receiveAsFlow()) + val client = MockSyncService(syncLines) syncStream = SyncStream( diff --git a/core/src/commonTest/kotlin/com/powersync/testutils/MockSyncService.kt b/core/src/commonTest/kotlin/com/powersync/testutils/MockSyncService.kt index eb57a232..56ab90ec 100644 --- a/core/src/commonTest/kotlin/com/powersync/testutils/MockSyncService.kt +++ b/core/src/commonTest/kotlin/com/powersync/testutils/MockSyncService.kt @@ -4,53 +4,85 @@ import app.cash.turbine.ReceiveTurbine import com.powersync.sync.SyncLine import com.powersync.sync.SyncStatusData import com.powersync.utils.JsonUtil -import io.ktor.client.engine.HttpClientEngine -import io.ktor.client.engine.mock.MockEngine -import io.ktor.client.engine.mock.MockRequestHandleScope -import io.ktor.client.engine.mock.respond -import io.ktor.client.engine.mock.respondBadRequest +import io.ktor.client.engine.HttpClientEngineBase +import io.ktor.client.engine.HttpClientEngineCapability +import io.ktor.client.engine.HttpClientEngineConfig +import io.ktor.client.engine.callContext +import io.ktor.client.plugins.HttpTimeoutCapability import io.ktor.client.request.HttpRequestData import io.ktor.client.request.HttpResponseData -import io.ktor.utils.io.ByteChannel +import io.ktor.http.HttpProtocolVersion +import io.ktor.http.HttpStatusCode +import io.ktor.http.headersOf +import io.ktor.util.date.GMTDate +import io.ktor.utils.io.InternalAPI +import io.ktor.utils.io.awaitFreeSpace import io.ktor.utils.io.writeStringUtf8 +import io.ktor.utils.io.writer import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.launch +import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.channels.consume import kotlinx.serialization.encodeToString -internal class MockSyncService private constructor( - private val scope: CoroutineScope, - private val lines: Flow, -) { - private fun handleRequest( - scope: MockRequestHandleScope, - request: HttpRequestData, - ): HttpResponseData = - if (request.url.encodedPath == "/sync/stream") { - val channel = ByteChannel(autoFlush = true) - this.scope.launch { - lines.collect { - val serializedLine = JsonUtil.json.encodeToString(it) - channel.writeStringUtf8("$serializedLine\n") +/** + * A mock HTTP engine providing sync lines read from a coroutines [ReceiveChannel]. + * + * Note that we can't trivially use ktor's `MockEngine` here because that engine requires a non-suspending handler + * function which makes it very hard to cancel the channel when the sync client closes the request stream. That is + * precisely what we may want to test though. + */ +internal class MockSyncService( + private val lines: ReceiveChannel, +) : HttpClientEngineBase("sync-service") { + override val config: HttpClientEngineConfig + get() = Config + + override val supportedCapabilities: Set> = + setOf( + HttpTimeoutCapability, + ) + + @OptIn(InternalAPI::class) + override suspend fun execute(data: HttpRequestData): HttpResponseData { + val context = callContext() + val scope = CoroutineScope(context) + + return if (data.url.encodedPath == "/sync/stream") { + val job = + scope.writer { + lines.consume { + while (true) { + // Wait for a downstream listener being ready before requesting a sync line + channel.awaitFreeSpace() + val line = receive() + val serializedLine = JsonUtil.json.encodeToString(line) + channel.writeStringUtf8("$serializedLine\n") + channel.flush() + } + } } - } - scope.respond(channel) + HttpResponseData( + HttpStatusCode.OK, + GMTDate(), + headersOf(), + HttpProtocolVersion.HTTP_1_1, + job.channel, + context, + ) } else { - scope.respondBadRequest() - } - - companion object { - fun client( - scope: CoroutineScope, - lines: Flow, - ): HttpClientEngine { - val service = MockSyncService(scope, lines) - return MockEngine { request -> - service.handleRequest(this, request) - } + HttpResponseData( + HttpStatusCode.BadRequest, + GMTDate(), + headersOf(), + HttpProtocolVersion.HTTP_1_1, + "", + context, + ) } } + + private object Config : HttpClientEngineConfig() } suspend inline fun ReceiveTurbine.waitFor(matcher: (SyncStatusData) -> Boolean) {