Skip to content

Commit d1945d7

Browse files
committed
Refactor push message handling, synchronization and coroutines
1 parent 5b68505 commit d1945d7

File tree

10 files changed

+197
-180
lines changed

10 files changed

+197
-180
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright © All Contributors. See LICENSE and AUTHORS in the root directory for details.
3+
*/
4+
5+
package at.bitfire.davdroid.push
6+
7+
import dagger.hilt.android.testing.HiltAndroidRule
8+
import dagger.hilt.android.testing.HiltAndroidTest
9+
import org.junit.Assert
10+
import org.junit.Before
11+
import org.junit.Rule
12+
import org.junit.Test
13+
import javax.inject.Inject
14+
15+
@HiltAndroidTest
16+
class PushMessageHandlerTest {
17+
18+
@get:Rule
19+
val hiltRule = HiltAndroidRule(this)
20+
21+
@Inject
22+
lateinit var handler: PushMessageHandler
23+
24+
@Before
25+
fun setUp() {
26+
hiltRule.inject()
27+
}
28+
29+
30+
@Test
31+
fun testParse_InvalidXml() {
32+
Assert.assertNull(handler.parse("Non-XML content"))
33+
}
34+
35+
@Test
36+
fun testParse_WithXmlDeclAndTopic() {
37+
val topic = handler.parse(
38+
"<?xml version=\"1.0\" encoding=\"utf-8\" ?>" +
39+
"<P:push-message xmlns:D=\"DAV:\" xmlns:P=\"https://bitfire.at/webdav-push\">" +
40+
" <P:topic>O7M1nQ7cKkKTKsoS_j6Z3w</P:topic>" +
41+
"</P:push-message>"
42+
)
43+
Assert.assertEquals("O7M1nQ7cKkKTKsoS_j6Z3w", topic)
44+
}
45+
46+
}

app/src/main/kotlin/at/bitfire/davdroid/db/CollectionDao.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ interface CollectionDao {
3333
fun getByServiceAndType(serviceId: Long, @CollectionType type: String): List<Collection>
3434

3535
@Query("SELECT * FROM collection WHERE pushTopic=:topic AND sync")
36-
fun getSyncableByPushTopic(topic: String): Collection?
36+
suspend fun getSyncableByPushTopic(topic: String): Collection?
3737

3838
@Query("SELECT pushVapidKey FROM collection WHERE serviceId=:serviceId AND pushVapidKey IS NOT NULL LIMIT 1")
3939
suspend fun getFirstVapidKey(serviceId: Long): String?

app/src/main/kotlin/at/bitfire/davdroid/db/ServiceDao.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ interface ServiceDao {
2525
@Query("SELECT * FROM service WHERE id=:id")
2626
fun get(id: Long): Service?
2727

28+
@Query("SELECT * FROM service WHERE id=:id")
29+
suspend fun getAsync(id: Long): Service?
30+
2831
@Query("SELECT * FROM service")
2932
suspend fun getAll(): List<Service>
3033

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Copyright © All Contributors. See LICENSE and AUTHORS in the root directory for details.
3+
*/
4+
5+
package at.bitfire.davdroid.push
6+
7+
import androidx.annotation.VisibleForTesting
8+
import at.bitfire.dav4jvm.XmlReader
9+
import at.bitfire.dav4jvm.XmlUtils
10+
import at.bitfire.davdroid.db.Collection.Companion.TYPE_ADDRESSBOOK
11+
import at.bitfire.davdroid.repository.AccountRepository
12+
import at.bitfire.davdroid.repository.DavCollectionRepository
13+
import at.bitfire.davdroid.repository.DavServiceRepository
14+
import at.bitfire.davdroid.sync.SyncDataType
15+
import at.bitfire.davdroid.sync.TasksAppManager
16+
import at.bitfire.davdroid.sync.worker.SyncWorkerManager
17+
import dagger.Lazy
18+
import org.unifiedpush.android.connector.data.PushMessage
19+
import org.xmlpull.v1.XmlPullParserException
20+
import java.io.StringReader
21+
import java.util.logging.Level
22+
import java.util.logging.Logger
23+
import javax.inject.Inject
24+
import at.bitfire.dav4jvm.property.push.PushMessage as DavPushMessage
25+
26+
/**
27+
* Handles incoming WebDAV-Push messages.
28+
*/
29+
class PushMessageHandler @Inject constructor(
30+
private val accountRepository: AccountRepository,
31+
private val collectionRepository: DavCollectionRepository,
32+
private val logger: Logger,
33+
private val serviceRepository: DavServiceRepository,
34+
private val syncWorkerManager: SyncWorkerManager,
35+
private val tasksAppManager: Lazy<TasksAppManager>
36+
) {
37+
38+
suspend fun processMessage(message: PushMessage, instance: String) {
39+
if (!message.decrypted) {
40+
logger.severe("Received a push message that could not be decrypted.")
41+
return
42+
}
43+
val messageXml = message.content.toString(Charsets.UTF_8)
44+
logger.log(Level.INFO, "Received push message", messageXml)
45+
46+
// parse push notification
47+
val topic = parse(messageXml)
48+
49+
// sync affected collection
50+
if (topic != null) {
51+
logger.info("Got push notification for topic $topic")
52+
53+
// Sync all authorities of account that the collection belongs to
54+
// Later: only sync affected collection and authorities
55+
collectionRepository.getSyncableByTopic(topic)?.let { collection ->
56+
serviceRepository.getAsync(collection.serviceId)?.let { service ->
57+
val syncDataTypes = mutableSetOf<SyncDataType>()
58+
// If the type is an address book, add the contacts type
59+
if (collection.type == TYPE_ADDRESSBOOK)
60+
syncDataTypes += SyncDataType.CONTACTS
61+
62+
// If the collection supports events, add the events type
63+
if (collection.supportsVEVENT != false)
64+
syncDataTypes += SyncDataType.EVENTS
65+
66+
// If the collection supports tasks, make sure there's a provider installed,
67+
// and add the tasks type
68+
if (collection.supportsVJOURNAL != false || collection.supportsVTODO != false)
69+
if (tasksAppManager.get().currentProvider() != null)
70+
syncDataTypes += SyncDataType.TASKS
71+
72+
// Schedule sync for all the types identified
73+
val account = accountRepository.fromName(service.accountName)
74+
for (syncDataType in syncDataTypes)
75+
syncWorkerManager.enqueueOneTime(account, syncDataType, fromPush = true)
76+
}
77+
}
78+
79+
} else {
80+
// fallback when no known topic is present (shouldn't happen)
81+
val service = instance.toLongOrNull()?.let { serviceRepository.get(it) }
82+
if (service != null) {
83+
logger.warning("Got push message without topic and service, syncing all accounts")
84+
val account = accountRepository.fromName(service.accountName)
85+
syncWorkerManager.enqueueOneTimeAllAuthorities(account, fromPush = true)
86+
87+
} else {
88+
logger.warning("Got push message without topic, syncing all accounts")
89+
for (account in accountRepository.getAll())
90+
syncWorkerManager.enqueueOneTimeAllAuthorities(account, fromPush = true)
91+
}
92+
}
93+
}
94+
95+
/**
96+
* Parses a WebDAV-Push message and returns the `topic` that the message is about.
97+
*
98+
* @return topic of the modified collection, or `null` if the topic couldn't be determined
99+
*/
100+
@VisibleForTesting
101+
internal fun parse(message: String): String? {
102+
var topic: String? = null
103+
104+
val parser = XmlUtils.newPullParser()
105+
try {
106+
parser.setInput(StringReader(message))
107+
108+
XmlReader(parser).processTag(DavPushMessage.NAME) {
109+
val pushMessage = DavPushMessage.Factory.create(parser)
110+
topic = pushMessage.topic?.topic
111+
}
112+
} catch (e: XmlPullParserException) {
113+
logger.log(Level.WARNING, "Couldn't parse push message", e)
114+
}
115+
116+
return topic
117+
}
118+
119+
}

app/src/main/kotlin/at/bitfire/davdroid/push/PushMessageParser.kt

Lines changed: 0 additions & 43 deletions
This file was deleted.

app/src/main/kotlin/at/bitfire/davdroid/push/PushRegistrationManager.kt

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@ import at.bitfire.dav4jvm.property.push.WebPushSubscription
2626
import at.bitfire.davdroid.db.Collection
2727
import at.bitfire.davdroid.db.Service
2828
import at.bitfire.davdroid.network.HttpClient
29+
import at.bitfire.davdroid.push.PushRegistrationManager.Companion.mutex
2930
import at.bitfire.davdroid.repository.AccountRepository
3031
import at.bitfire.davdroid.repository.DavCollectionRepository
3132
import at.bitfire.davdroid.repository.DavServiceRepository
3233
import dagger.Lazy
3334
import dagger.hilt.android.qualifiers.ApplicationContext
34-
import kotlinx.coroutines.Dispatchers
3535
import kotlinx.coroutines.runInterruptible
36-
import kotlinx.coroutines.withContext
36+
import kotlinx.coroutines.sync.Mutex
37+
import kotlinx.coroutines.sync.withLock
3738
import okhttp3.HttpUrl
3839
import okhttp3.HttpUrl.Companion.toHttpUrlOrNull
3940
import okhttp3.RequestBody.Companion.toRequestBody
@@ -52,6 +53,8 @@ import javax.inject.Provider
5253
* Manages push registrations and subscriptions.
5354
*
5455
* To update push registrations and subscriptions (for instance after collections have been changed), call [update].
56+
*
57+
* Public API calls are protected by [mutex] so that there won't be multiple subscribe/unsubscribe operations at the same time.
5558
*/
5659
class PushRegistrationManager @Inject constructor(
5760
private val accountRepository: Lazy<AccountRepository>,
@@ -68,7 +71,7 @@ class PushRegistrationManager @Inject constructor(
6871
*
6972
* Also makes sure that the [PushRegistrationWorker] is enabled if there's a Push-enabled collection.
7073
*/
71-
suspend fun update() = withContext(dispatcher) {
74+
suspend fun update() = mutex.withLock {
7275
for (service in serviceRepository.getAll())
7376
updateService(service.id)
7477

@@ -78,13 +81,13 @@ class PushRegistrationManager @Inject constructor(
7881
/**
7982
* Same as [update], but for a specific database service.
8083
*/
81-
suspend fun update(serviceId: Long) = withContext(dispatcher) {
84+
suspend fun update(serviceId: Long) = mutex.withLock {
8285
updateService(serviceId)
8386
updatePeriodicWorker()
8487
}
8588

8689
private suspend fun updateService(serviceId: Long) {
87-
val service = serviceRepository.get(serviceId) ?: return
90+
val service = serviceRepository.getAsync(serviceId) ?: return
8891
val vapid = collectionRepository.getVapidKey(serviceId)
8992

9093
if (vapid != null)
@@ -102,12 +105,12 @@ class PushRegistrationManager @Inject constructor(
102105

103106

104107
/**
105-
* Called when a subscription (endpoint) is available for the given service.
108+
* Called by [UnifiedPushService] when a subscription (endpoint) is available for the given service.
106109
*
107110
* Uses the subscription to subscribe to syncable collections, and then unsubscribes from non-syncable collections.
108111
*/
109-
internal suspend fun processSubscription(serviceId: Long, endpoint: PushEndpoint) = withContext(dispatcher) {
110-
val service = serviceRepository.get(serviceId) ?: return@withContext
112+
internal suspend fun processSubscription(serviceId: Long, endpoint: PushEndpoint) = mutex.withLock {
113+
val service = serviceRepository.getAsync(serviceId) ?: return
111114

112115
subscribeSyncable(service, endpoint)
113116
unsubscribeNotSyncable(service)
@@ -164,11 +167,11 @@ class PushRegistrationManager @Inject constructor(
164167
*
165168
* Unsubscribes from all subscribed collections.
166169
*/
167-
internal suspend fun removeSubscription(serviceId: Long) = withContext(dispatcher) {
168-
val service = serviceRepository.get(serviceId) ?: return@withContext
170+
internal suspend fun removeSubscription(serviceId: Long) = mutex.withLock {
171+
val service = serviceRepository.getAsync(serviceId) ?: return
169172
val unsubscribeFrom = collectionRepository.getPushRegistered(service.id)
170173
if (unsubscribeFrom.isEmpty())
171-
return@withContext
174+
return
172175

173176
val account = accountRepository.get().fromName(service.accountName)
174177
httpClientBuilder.get()
@@ -272,7 +275,7 @@ class PushRegistrationManager @Inject constructor(
272275
*
273276
* Otherwise, a potentially existing worker is cancelled.
274277
*/
275-
suspend fun updatePeriodicWorker() = withContext(dispatcher) {
278+
private suspend fun updatePeriodicWorker() {
276279
val workerNeeded = collectionRepository.anyPushCapable()
277280

278281
val workManager = WorkManager.getInstance(context)
@@ -303,10 +306,7 @@ class PushRegistrationManager @Inject constructor(
303306
private const val WORKER_UNIQUE_NAME = "push-registration"
304307
const val WORKER_INTERVAL_DAYS = 1L
305308

306-
/**
307-
* Single-thread dispatcher to synchronize non-private calls.
308-
*/
309-
val dispatcher = Dispatchers.IO.limitedParallelism(1)
309+
val mutex = Mutex()
310310

311311
}
312312

0 commit comments

Comments
 (0)