Skip to content

Commit 1dadfda

Browse files
committed
[WIP] Remove PushRegistrationWorkerManager and refactor PushRegistrationManager
1 parent 2adea34 commit 1dadfda

File tree

8 files changed

+235
-345
lines changed

8 files changed

+235
-345
lines changed

app/src/androidTest/kotlin/at/bitfire/davdroid/TestModules.kt

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,13 @@
44

55
package at.bitfire.davdroid
66

7-
import at.bitfire.davdroid.push.PushRegistrationWorkerManager
8-
import at.bitfire.davdroid.repository.DavCollectionRepository
97
import at.bitfire.davdroid.startup.StartupPlugin
108
import at.bitfire.davdroid.startup.TasksAppWatcher
119
import dagger.Module
1210
import dagger.hilt.components.SingletonComponent
1311
import dagger.hilt.testing.TestInstallIn
1412
import dagger.multibindings.Multibinds
1513

16-
// remove PushRegistrationWorkerModule from Android tests
17-
@Module
18-
@TestInstallIn(
19-
components = [SingletonComponent::class],
20-
replaces = [PushRegistrationWorkerManager.PushRegistrationWorkerModule::class]
21-
)
22-
abstract class TestPushRegistrationWorkerModule {
23-
// provides empty set of listeners
24-
@Multibinds
25-
abstract fun empty(): Set<DavCollectionRepository.OnChangeListener>
26-
}
27-
2814
// remove TasksAppWatcherModule from Android tests
2915
@Module
3016
@TestInstallIn(

app/src/main/kotlin/at/bitfire/davdroid/db/CollectionDao.kt

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ interface CollectionDao {
2424
fun getFlow(id: Long): Flow<Collection?>
2525

2626
@Query("SELECT * FROM collection WHERE serviceId=:serviceId")
27-
fun getByService(serviceId: Long): List<Collection>
27+
suspend fun getByService(serviceId: Long): List<Collection>
2828

2929
@Query("SELECT * FROM collection WHERE serviceId=:serviceId AND homeSetId IS :homeSetId")
3030
fun getByServiceAndHomeset(serviceId: Long, homeSetId: Long?): List<Collection>
@@ -36,7 +36,7 @@ interface CollectionDao {
3636
fun getSyncableByPushTopic(topic: String): Collection?
3737

3838
@Query("SELECT pushVapidKey FROM collection WHERE serviceId=:serviceId AND pushVapidKey IS NOT NULL LIMIT 1")
39-
fun getFirstVapidKey(serviceId: Long): String?
39+
suspend fun getFirstVapidKey(serviceId: Long): String?
4040

4141
@Query("SELECT COUNT(*) FROM collection WHERE serviceId=:serviceId AND type=:type")
4242
suspend fun anyOfType(serviceId: Long, @CollectionType type: String): Boolean
@@ -78,8 +78,11 @@ interface CollectionDao {
7878
@Query("SELECT * FROM collection WHERE serviceId=:serviceId AND sync AND supportsWebPush AND pushTopic IS NOT NULL")
7979
fun getPushCapableSyncCollections(serviceId: Long): List<Collection>
8080

81-
@Query("SELECT * FROM collection WHERE pushSubscription IS NOT NULL AND NOT sync")
82-
suspend fun getPushRegisteredAndNotSyncable(): List<Collection>
81+
@Query("SELECT * FROM collection WHERE serviceId=:serviceId AND pushSubscription IS NOT NULL")
82+
suspend fun getPushRegistered(serviceId: Long): List<Collection>
83+
84+
@Query("SELECT * FROM collection WHERE serviceId=:serviceId AND pushSubscription IS NOT NULL AND NOT sync")
85+
suspend fun getPushRegisteredAndNotSyncable(serviceId: Long): List<Collection>
8386

8487
@Insert(onConflict = OnConflictStrategy.IGNORE)
8588
fun insert(collection: Collection): Long

app/src/main/kotlin/at/bitfire/davdroid/push/PushRegistrationManager.kt

Lines changed: 190 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,50 +5,82 @@
55
package at.bitfire.davdroid.push
66

77
import android.content.Context
8+
import androidx.work.BackoffPolicy
9+
import androidx.work.Constraints
10+
import androidx.work.ExistingPeriodicWorkPolicy
11+
import androidx.work.NetworkType
12+
import androidx.work.PeriodicWorkRequest
13+
import androidx.work.WorkManager
814
import at.bitfire.dav4jvm.DavCollection
915
import at.bitfire.dav4jvm.DavResource
1016
import at.bitfire.dav4jvm.HttpUtils
1117
import at.bitfire.dav4jvm.XmlUtils
1218
import at.bitfire.dav4jvm.XmlUtils.insertTag
19+
import at.bitfire.dav4jvm.exception.DavException
1320
import at.bitfire.dav4jvm.property.push.AuthSecret
1421
import at.bitfire.dav4jvm.property.push.PushRegister
1522
import at.bitfire.dav4jvm.property.push.PushResource
1623
import at.bitfire.dav4jvm.property.push.Subscription
1724
import at.bitfire.dav4jvm.property.push.SubscriptionPublicKey
1825
import at.bitfire.dav4jvm.property.push.WebPushSubscription
1926
import at.bitfire.davdroid.db.Collection
27+
import at.bitfire.davdroid.db.Service
2028
import at.bitfire.davdroid.network.HttpClient
2129
import at.bitfire.davdroid.repository.AccountRepository
2230
import at.bitfire.davdroid.repository.DavCollectionRepository
2331
import at.bitfire.davdroid.repository.DavServiceRepository
32+
import dagger.Lazy
2433
import dagger.hilt.android.qualifiers.ApplicationContext
34+
import kotlinx.coroutines.Dispatchers
35+
import kotlinx.coroutines.runBlocking
36+
import kotlinx.coroutines.runInterruptible
37+
import kotlinx.coroutines.withContext
38+
import okhttp3.HttpUrl
39+
import okhttp3.HttpUrl.Companion.toHttpUrlOrNull
2540
import okhttp3.RequestBody.Companion.toRequestBody
2641
import org.unifiedpush.android.connector.UnifiedPush
2742
import org.unifiedpush.android.connector.data.PushEndpoint
2843
import java.io.StringWriter
2944
import java.time.Duration
3045
import java.time.Instant
46+
import java.util.concurrent.TimeUnit
3147
import java.util.logging.Level
3248
import java.util.logging.Logger
3349
import javax.inject.Inject
3450
import javax.inject.Provider
3551

3652
class PushRegistrationManager @Inject constructor(
37-
private val accountRepository: AccountRepository,
53+
private val accountRepository: Lazy<AccountRepository>,
3854
private val collectionRepository: DavCollectionRepository,
3955
@ApplicationContext private val context: Context,
4056
private val httpClientBuilder: Provider<HttpClient.Builder>,
4157
private val logger: Logger,
4258
private val serviceRepository: DavServiceRepository
4359
) {
4460

45-
fun update() {
61+
/**
62+
* Updates all push registrations and subscriptions so that if Push is available, it's up-to-date and
63+
* working for all database services.
64+
*
65+
* Also makes sure that the [PushRegistrationWorker] is enabled if there's a Push-enabled collection.
66+
*/
67+
suspend fun update() = withContext(dispatcher) {
4668
for (service in serviceRepository.getAll())
47-
update(service.id)
69+
updateService(service.id)
70+
71+
updatePeriodicWorker()
4872
}
4973

50-
fun update(serviceId: Long) {
51-
val service = serviceRepository.get(serviceId) ?: return
74+
/**
75+
* Same as [update], but for a specific database service.
76+
*/
77+
suspend fun update(serviceId: Long) {
78+
updateService(serviceId)
79+
updatePeriodicWorker()
80+
}
81+
82+
private suspend fun updateService(serviceId: Long) = withContext(dispatcher) {
83+
val service = serviceRepository.get(serviceId) ?: return@withContext
5284
val vapid = collectionRepository.getVapidKey(serviceId)
5385

5486
if (vapid != null)
@@ -59,39 +91,103 @@ class PushRegistrationManager @Inject constructor(
5991
}
6092
else
6193
UnifiedPush.unregister(context, serviceId.toString())
94+
95+
// UnifiedPush has now been called. It will do its work and then call back to UnifiedPushService, which
96+
// will then call processSubscription or removeSubscription.
6297
}
6398

64-
fun registerSubscription(serviceId: Long, endpoint: PushEndpoint) {
65-
val service = serviceRepository.get(serviceId) ?: return
6699

67-
val collectionsToRegister = collectionRepository.getPushCapableAndSyncable(serviceId)
68-
if (collectionsToRegister.isEmpty())
100+
/**
101+
* Called when a subscription (endpoint) is available for the given service.
102+
*
103+
* Uses the subscription to subscribe to syncable collections, and then unsubscribes from non-syncable collections.
104+
*/
105+
internal suspend fun processSubscription(serviceId: Long, endpoint: PushEndpoint) = withContext(dispatcher) {
106+
val service = serviceRepository.get(serviceId) ?: return@withContext
107+
108+
subscribeSyncable(service, endpoint)
109+
unsubscribeNotSyncable(service)
110+
}
111+
112+
private suspend fun subscribeSyncable(service: Service, endpoint: PushEndpoint) {
113+
val subscribeTo = collectionRepository.getPushCapableAndSyncable(service.id)
114+
if (subscribeTo.isEmpty())
69115
return
70116

71-
val account = accountRepository.fromName(service.accountName)
117+
val account = accountRepository.get().fromName(service.accountName)
72118
httpClientBuilder.get()
73119
.fromAccount(account)
74120
.build()
75121
.use { httpClient ->
76-
for (collection in collectionsToRegister)
122+
for (collection in subscribeTo)
77123
try {
78124
val expires = collection.pushSubscriptionExpires
79125
// calculate next run time, but use the duplicate interval for safety (times are not exact)
80-
val nextRun = Instant.now() + Duration.ofDays(2 * PushRegistrationWorkerManager.INTERVAL_DAYS)
126+
val nextRun = Instant.now() + Duration.ofDays(2 * WORKER_INTERVAL_DAYS)
81127
if (expires != null && expires >= nextRun.epochSecond)
82128
logger.fine("Push subscription for ${collection.url} is still valid until ${collection.pushSubscriptionExpires}")
83129
else {
84130
// no existing subscription or expiring soon
85131
logger.fine("Registering push subscription for ${collection.url}")
86-
registerSubscription(httpClient, collection, endpoint)
132+
subscribe(httpClient, collection, endpoint)
87133
}
88134
} catch (e: Exception) {
89135
logger.log(Level.WARNING, "Couldn't register subscription at CalDAV/CardDAV server", e)
90136
}
91137
}
92138
}
93139

94-
private fun registerSubscription(httpClient: HttpClient, collection: Collection, endpoint: PushEndpoint) {
140+
private suspend fun unsubscribeNotSyncable(service: Service) {
141+
val unsubscribeFrom = collectionRepository.getPushRegisteredAndNotSyncable(service.id)
142+
if (unsubscribeFrom.isEmpty())
143+
return
144+
145+
val account = accountRepository.get().fromName(service.accountName)
146+
httpClientBuilder.get()
147+
.fromAccount(account)
148+
.build()
149+
.use { httpClient ->
150+
for (collection in unsubscribeFrom)
151+
collection.pushSubscription?.toHttpUrlOrNull()?.let { url ->
152+
logger.info("Unregistering push for ${collection.url}")
153+
unsubscribe(httpClient, collection, url)
154+
}
155+
}
156+
}
157+
158+
/**
159+
* Called when no subscription is available (anymore) for the given service.
160+
*
161+
* Unsubscribes from all collections.
162+
*/
163+
internal suspend fun removeSubscription(serviceId: Long) {
164+
val service = serviceRepository.get(serviceId) ?: return
165+
val unsubscribeFrom = collectionRepository.getPushRegistered(service.id)
166+
if (unsubscribeFrom.isEmpty())
167+
return
168+
169+
val account = accountRepository.get().fromName(service.accountName)
170+
httpClientBuilder.get()
171+
.fromAccount(account)
172+
.build()
173+
.use { httpClient ->
174+
for (collection in unsubscribeFrom)
175+
collection.pushSubscription?.toHttpUrlOrNull()?.let { url ->
176+
logger.info("Unregistering push for ${collection.url}")
177+
unsubscribe(httpClient, collection, url)
178+
}
179+
}
180+
}
181+
182+
183+
/**
184+
* Registers the subscription to a given collection ("subscribe to a collection").
185+
*
186+
* @param httpClient HTTP client to use
187+
* @param collection collection to subscribe to
188+
* @param endpoint subscription to register
189+
*/
190+
private suspend fun subscribe(httpClient: HttpClient, collection: Collection, endpoint: PushEndpoint) {
95191
// requested expiration time: 3 days
96192
val requestedExpiration = Instant.now() + Duration.ofDays(3)
97193

@@ -124,26 +220,89 @@ class PushRegistrationManager @Inject constructor(
124220
}
125221
serializer.endDocument()
126222

127-
val xml = writer.toString().toRequestBody(DavResource.MIME_XML)
128-
DavCollection(httpClient.okHttpClient, collection.url).post(xml) { response ->
129-
if (response.isSuccessful) {
130-
// update subscription URL and expiration in DB
131-
val subscriptionUrl = response.header("Location")
132-
val expires = response.header("Expires")?.let { expiresDate ->
133-
HttpUtils.parseDate(expiresDate)
134-
} ?: requestedExpiration
135-
collectionRepository.updatePushSubscription(
136-
id = collection.id,
137-
subscriptionUrl = subscriptionUrl,
138-
expires = expires?.epochSecond
139-
)
140-
} else
141-
logger.warning("Couldn't register push for ${collection.url}: $response")
223+
runInterruptible {
224+
val xml = writer.toString().toRequestBody(DavResource.MIME_XML)
225+
DavCollection(httpClient.okHttpClient, collection.url).post(xml) { response ->
226+
if (response.isSuccessful) {
227+
// update subscription URL and expiration in DB
228+
val subscriptionUrl = response.header("Location")
229+
val expires = response.header("Expires")?.let { expiresDate ->
230+
HttpUtils.parseDate(expiresDate)
231+
} ?: requestedExpiration
232+
collectionRepository.updatePushSubscription(
233+
id = collection.id,
234+
subscriptionUrl = subscriptionUrl,
235+
expires = expires?.epochSecond
236+
)
237+
} else
238+
logger.warning("Couldn't register push for ${collection.url}: $response")
239+
}
240+
}
241+
}
242+
243+
private suspend fun unsubscribe(httpClient: HttpClient, collection: Collection, url: HttpUrl) {
244+
runInterruptible {
245+
try {
246+
DavResource(httpClient.okHttpClient, url).delete {
247+
// deleted
248+
}
249+
} catch (e: DavException) {
250+
logger.log(Level.WARNING, "Couldn't unregister push for ${collection.url}", e)
251+
}
252+
253+
// remove registration URL from DB in any case
254+
collectionRepository.updatePushSubscription(
255+
id = collection.id,
256+
subscriptionUrl = null,
257+
expires = null
258+
)
259+
}
260+
}
261+
262+
263+
/**
264+
* Determines whether there are any push-capable collections and updates the periodic worker accordingly.
265+
*
266+
* If there are push-capable collections, a unique periodic worker with an initial delay of 5 seconds is enqueued.
267+
* A potentially existing worker is replaced, so that the first run should be soon.
268+
*
269+
* Otherwise, a potentially existing worker is cancelled.
270+
*/
271+
fun updatePeriodicWorker() {
272+
val workerNeeded = runBlocking {
273+
collectionRepository.anyPushCapable()
274+
}
275+
276+
val workManager = WorkManager.getInstance(context)
277+
if (workerNeeded) {
278+
logger.info("Enqueuing periodic PushRegistrationWorker")
279+
workManager.enqueueUniquePeriodicWork(
280+
WORKER_UNIQUE_NAME,
281+
ExistingPeriodicWorkPolicy.UPDATE,
282+
PeriodicWorkRequest.Builder(PushRegistrationWorker::class, WORKER_INTERVAL_DAYS, TimeUnit.DAYS)
283+
.setInitialDelay(5, TimeUnit.SECONDS)
284+
.setConstraints(
285+
Constraints.Builder()
286+
.setRequiredNetworkType(NetworkType.CONNECTED)
287+
.build()
288+
)
289+
.setBackoffCriteria(BackoffPolicy.EXPONENTIAL, 1, TimeUnit.MINUTES)
290+
.build()
291+
)
292+
} else {
293+
logger.info("Cancelling periodic PushRegistrationWorker")
294+
workManager.cancelUniqueWork(WORKER_UNIQUE_NAME)
142295
}
143296
}
144297

145-
fun unregisterSubscription(serviceId: Long) {
146-
// TODO
298+
299+
companion object {
300+
301+
private const val WORKER_UNIQUE_NAME = "push-registration"
302+
const val WORKER_INTERVAL_DAYS = 1L
303+
304+
val dispatcher = Dispatchers.IO.limitedParallelism(1)
305+
147306
}
148307

149308
}

0 commit comments

Comments
 (0)