Skip to content

Commit fa84fdf

Browse files
committed
- Refactor: Enable concurrent model downloads
This change refactors the model download service to support downloading multiple models simultaneously. Key changes: - Replaced the single `downloadState` with a `downloadStates` map, allowing individual tracking of each download's progress, keyed by model ID. - Each download now runs in its own coroutine job and displays a separate notification, preventing conflicts. - Cancellation is now specific to a model ID, allowing users to cancel individual downloads without affecting others. - Implemented robust cleanup logic to delete temporary files and remove download state upon completion, cancellation, or error. - Added a new `EmbeddingModelDownloadWorker` to handle the download of the embedding model in the background, separate from the main model download service.
1 parent 8e4d8d2 commit fa84fdf

File tree

6 files changed

+286
-81
lines changed

6 files changed

+286
-81
lines changed

app/build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ android {
2222
applicationId = "com.dark.tool_neuron"
2323
minSdk = 31
2424
targetSdk = 36
25-
versionCode = 16
25+
versionCode = 17
2626
versionName = "1.1.2-Fix"
2727
ndk {
2828
abiFilters += listOf("arm64-v8a", "x86_64")

app/src/main/AndroidManifest.xml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC" />
88

99

10-
11-
1210
<application
1311
android:name=".NVApplication"
1412
android:allowBackup="true"

app/src/main/java/com/dark/tool_neuron/service/ModelDownloadService.kt

Lines changed: 130 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ import java.util.zip.ZipInputStream
3333
class ModelDownloadService : Service() {
3434

3535
private val serviceScope = CoroutineScope(Dispatchers.IO + Job())
36-
private var downloadJob: Job? = null
36+
private val downloadJobs = mutableMapOf<String, Job>()
37+
private var notificationIdCounter = NOTIFICATION_ID
3738

3839
private val notificationManager by lazy {
3940
getSystemService(NOTIFICATION_SERVICE) as NotificationManager
@@ -46,8 +47,8 @@ class ModelDownloadService : Service() {
4647
private const val NOTIFICATION_CHANNEL_ID = "model_download_channel"
4748
private const val NOTIFICATION_ID = 3001
4849

49-
private val _downloadState = MutableStateFlow<DownloadState>(DownloadState.Idle)
50-
val downloadState: StateFlow<DownloadState> = _downloadState
50+
private val _downloadStates = MutableStateFlow<Map<String, DownloadState>>(emptyMap())
51+
val downloadStates: StateFlow<Map<String, DownloadState>> = _downloadStates
5152

5253
const val ACTION_START_DOWNLOAD = "action_start_download"
5354
const val ACTION_CANCEL_DOWNLOAD = "action_cancel_download"
@@ -62,7 +63,6 @@ class ModelDownloadService : Service() {
6263
}
6364

6465
sealed class DownloadState {
65-
object Idle : DownloadState()
6666
data class Downloading(
6767
val modelId: String,
6868
val progress: Float,
@@ -74,6 +74,15 @@ class ModelDownloadService : Service() {
7474
data class Processing(val modelId: String) : DownloadState()
7575
data class Success(val modelId: String) : DownloadState()
7676
data class Error(val modelId: String, val message: String) : DownloadState()
77+
data class Cancelled(val modelId: String) : DownloadState()
78+
}
79+
80+
private fun updateDownloadState(modelId: String, state: DownloadState?) {
81+
_downloadStates.value = if (state == null) {
82+
_downloadStates.value - modelId
83+
} else {
84+
_downloadStates.value + (modelId to state)
85+
}
7786
}
7887

7988
override fun onCreate() {
@@ -105,7 +114,10 @@ class ModelDownloadService : Service() {
105114
}
106115

107116
ACTION_CANCEL_DOWNLOAD -> {
108-
cancelDownload()
117+
val modelId = intent.getStringExtra(EXTRA_MODEL_ID)
118+
if (modelId != null) {
119+
cancelDownload(modelId)
120+
}
109121
}
110122
}
111123
return START_NOT_STICKY
@@ -120,22 +132,25 @@ class ModelDownloadService : Service() {
120132
runOnCpu: Boolean,
121133
textEmbeddingSize: Int
122134
) {
123-
downloadJob?.cancel()
124-
downloadJob = serviceScope.launch {
135+
// Cancel existing download for this model if any
136+
downloadJobs[modelId]?.cancel()
137+
138+
val notificationId = ++notificationIdCounter
139+
val job = serviceScope.launch {
125140
var tempFile: File? = null
126141
var extractTempDir: File? = null
127142
try {
128-
_downloadState.value = DownloadState.Downloading(modelId, 0f, 0, 0)
143+
updateDownloadState(modelId, DownloadState.Downloading(modelId, 0f, 0, 0))
129144

130-
val tempDir = File(filesDir, "temp_downloads")
145+
val tempDir = File(filesDir, "temp_downloads/$modelId")
131146
if (tempDir.exists()) {
132147
tempDir.deleteRecursively()
133148
}
134149
tempDir.mkdirs()
135150

136151
tempFile = File(tempDir, "${modelId}_${System.currentTimeMillis()}.tmp")
137152

138-
downloadFile(fileUrl, tempFile, modelId, modelName)
153+
downloadFile(fileUrl, tempFile, modelId, modelName, notificationId)
139154

140155
when (modelType) {
141156
"SD" -> {
@@ -153,10 +168,10 @@ class ModelDownloadService : Service() {
153168
extractTempDir = File(tempDir, "${modelId}_extract")
154169
extractTempDir.mkdirs()
155170

156-
_downloadState.value = DownloadState.Extracting(modelId)
157-
updateNotification(modelName, 0f, isExtracting = true)
171+
updateDownloadState(modelId, DownloadState.Extracting(modelId))
172+
updateNotification(modelName, 0f, notificationId, isExtracting = true)
158173

159-
unzipFile(tempFile, extractTempDir)
174+
unzipFile(tempFile, extractTempDir, modelId)
160175

161176
extractTempDir.listFiles()?.forEach { file ->
162177
file.copyRecursively(File(modelDir, file.name), overwrite = true)
@@ -170,8 +185,8 @@ class ModelDownloadService : Service() {
170185
tempFile.copyTo(File(modelDir, tempFile.name), overwrite = true)
171186
}
172187

173-
_downloadState.value = DownloadState.Processing(modelId)
174-
updateNotification(modelName, 0f, isProcessing = true)
188+
updateDownloadState(modelId, DownloadState.Processing(modelId))
189+
updateNotification(modelName, 0f, notificationId, isProcessing = true)
175190

176191
insertModelToDatabase(
177192
modelId = modelId,
@@ -195,8 +210,8 @@ class ModelDownloadService : Service() {
195210

196211
tempFile.copyTo(targetFile, overwrite = true)
197212

198-
_downloadState.value = DownloadState.Processing(modelId)
199-
updateNotification(modelName, 0f, isProcessing = true)
213+
updateDownloadState(modelId, DownloadState.Processing(modelId))
214+
updateNotification(modelName, 0f, notificationId, isProcessing = true)
200215

201216
insertModelToDatabase(
202217
modelId = modelId,
@@ -211,82 +226,129 @@ class ModelDownloadService : Service() {
211226

212227
tempFile?.delete()
213228
tempFile = null
229+
tempDir.deleteRecursively()
214230

215-
_downloadState.value = DownloadState.Success(modelId)
216-
updateNotification(modelName, 100f, isSuccess = true)
231+
updateDownloadState(modelId, DownloadState.Success(modelId))
232+
updateNotification(modelName, 100f, notificationId, isSuccess = true)
217233

218234
withContext(Dispatchers.Main) {
219235
kotlinx.coroutines.delay(2000)
220-
_downloadState.value = DownloadState.Idle
221-
stopForeground(STOP_FOREGROUND_REMOVE)
222-
stopSelf()
236+
updateDownloadState(modelId, null)
237+
downloadJobs.remove(modelId)
238+
239+
if (downloadJobs.isEmpty()) {
240+
stopForeground(STOP_FOREGROUND_REMOVE)
241+
stopSelf()
242+
}
223243
}
224244

245+
} catch (e: kotlinx.coroutines.CancellationException) {
246+
tempFile?.delete()
247+
extractTempDir?.deleteRecursively()
248+
File(filesDir, "temp_downloads/$modelId").deleteRecursively()
249+
250+
updateDownloadState(modelId, DownloadState.Cancelled(modelId))
251+
updateNotification(modelName, 0f, notificationId, isCancelled = true)
252+
253+
withContext(Dispatchers.Main) {
254+
kotlinx.coroutines.delay(2000)
255+
updateDownloadState(modelId, null)
256+
downloadJobs.remove(modelId)
257+
258+
if (downloadJobs.isEmpty()) {
259+
stopForeground(STOP_FOREGROUND_REMOVE)
260+
stopSelf()
261+
}
262+
}
225263
} catch (e: Exception) {
226264
tempFile?.delete()
227265
extractTempDir?.deleteRecursively()
266+
File(filesDir, "temp_downloads/$modelId").deleteRecursively()
228267

229-
_downloadState.value = DownloadState.Error(modelId, e.message ?: "Unknown error")
230-
updateNotification(modelName, 0f, error = e.message)
268+
updateDownloadState(modelId, DownloadState.Error(modelId, e.message ?: "Unknown error"))
269+
updateNotification(modelName, 0f, notificationId, error = e.message)
231270

232271
withContext(Dispatchers.Main) {
233272
kotlinx.coroutines.delay(3000)
234-
_downloadState.value = DownloadState.Idle
235-
stopForeground(STOP_FOREGROUND_REMOVE)
236-
stopSelf()
273+
updateDownloadState(modelId, null)
274+
downloadJobs.remove(modelId)
275+
276+
if (downloadJobs.isEmpty()) {
277+
stopForeground(STOP_FOREGROUND_REMOVE)
278+
stopSelf()
279+
}
237280
}
238281
}
239282
}
283+
284+
downloadJobs[modelId] = job
240285
}
241286

242287
private suspend fun downloadFile(
243-
url: String, destFile: File, modelId: String, modelName: String
288+
url: String, destFile: File, modelId: String, modelName: String, notificationId: Int
244289
) = withContext(Dispatchers.IO) {
245290
val request = Request.Builder().url(url).build()
291+
val call = client.newCall(request)
246292

247-
client.newCall(request).execute().use { response ->
248-
if (!response.isSuccessful) {
249-
throw Exception("Download failed with code: ${response.code}")
250-
}
251-
252-
val body = response.body ?: throw Exception("Response body is null")
253-
val totalBytes = body.contentLength()
254-
var downloadedBytes = 0L
255-
var lastUpdateTime = 0L
293+
try {
294+
call.execute().use { response ->
295+
if (!response.isSuccessful) {
296+
throw Exception("Download failed with code: ${response.code}")
297+
}
256298

257-
FileOutputStream(destFile).buffered().use { output ->
258-
body.byteStream().buffered().use { input ->
259-
val buffer = ByteArray(32 * 1024)
260-
var bytes: Int
299+
val body = response.body ?: throw Exception("Response body is null")
300+
val totalBytes = body.contentLength()
301+
var downloadedBytes = 0L
302+
var lastUpdateTime = 0L
303+
304+
FileOutputStream(destFile).buffered().use { output ->
305+
body.byteStream().buffered().use { input ->
306+
val buffer = ByteArray(32 * 1024)
307+
var bytes: Int
308+
309+
while (input.read(buffer).also { bytes = it } != -1) {
310+
// Check for cancellation
311+
if (!downloadJobs.containsKey(modelId) || downloadJobs[modelId]?.isCancelled == true) {
312+
call.cancel()
313+
throw kotlinx.coroutines.CancellationException("Download cancelled")
314+
}
261315

262-
while (input.read(buffer).also { bytes = it } != -1) {
263-
output.write(buffer, 0, bytes)
264-
downloadedBytes += bytes
316+
output.write(buffer, 0, bytes)
317+
downloadedBytes += bytes
265318

266-
val currentTime = System.currentTimeMillis()
267-
if (currentTime - lastUpdateTime >= 500 || downloadedBytes == totalBytes) {
268-
lastUpdateTime = currentTime
269-
val progress = if (totalBytes > 0) {
270-
downloadedBytes.toFloat() / totalBytes
271-
} else 0f
319+
val currentTime = System.currentTimeMillis()
320+
if (currentTime - lastUpdateTime >= 500 || downloadedBytes == totalBytes) {
321+
lastUpdateTime = currentTime
322+
val progress = if (totalBytes > 0) {
323+
downloadedBytes.toFloat() / totalBytes
324+
} else 0f
272325

273-
_downloadState.value = DownloadState.Downloading(
274-
modelId, progress, downloadedBytes, totalBytes
275-
)
326+
updateDownloadState(modelId, DownloadState.Downloading(
327+
modelId, progress, downloadedBytes, totalBytes
328+
))
276329

277-
updateNotification(modelName, progress)
330+
updateNotification(modelName, progress, notificationId)
331+
}
278332
}
279333
}
280334
}
281335
}
336+
} catch (e: Exception) {
337+
call.cancel()
338+
throw e
282339
}
283340
}
284341

285-
private suspend fun unzipFile(zipFile: File, destDir: File) = withContext(Dispatchers.IO) {
342+
private suspend fun unzipFile(zipFile: File, destDir: File, modelId: String) = withContext(Dispatchers.IO) {
286343
ZipInputStream(zipFile.inputStream().buffered()).use { zis ->
287344
var entry = zis.nextEntry
288345

289346
while (entry != null) {
347+
// Check for cancellation
348+
if (!downloadJobs.containsKey(modelId) || downloadJobs[modelId]?.isCancelled == true) {
349+
throw kotlinx.coroutines.CancellationException("Extraction cancelled")
350+
}
351+
290352
if (!entry.isDirectory) {
291353
val fileName = entry.name.substringAfterLast('/')
292354
if (fileName.isNotEmpty() && !fileName.startsWith(".") && !fileName.startsWith("__MACOSX")) {
@@ -379,11 +441,8 @@ class ModelDownloadService : Service() {
379441
repository.insertConfig(config)
380442
}
381443

382-
private fun cancelDownload() {
383-
downloadJob?.cancel()
384-
_downloadState.value = DownloadState.Idle
385-
stopForeground(STOP_FOREGROUND_REMOVE)
386-
stopSelf()
444+
private fun cancelDownload(modelId: String) {
445+
downloadJobs[modelId]?.cancel()
387446
}
388447

389448
private fun createNotificationChannel() {
@@ -416,10 +475,12 @@ class ModelDownloadService : Service() {
416475
private fun updateNotification(
417476
modelName: String,
418477
progress: Float,
478+
notificationId: Int,
419479
isSuccess: Boolean = false,
420480
error: String? = null,
421481
isExtracting: Boolean = false,
422-
isProcessing: Boolean = false
482+
isProcessing: Boolean = false,
483+
isCancelled: Boolean = false
423484
) {
424485
val notification = when {
425486
isSuccess -> {
@@ -429,6 +490,13 @@ class ModelDownloadService : Service() {
429490
.build()
430491
}
431492

493+
isCancelled -> {
494+
NotificationCompat.Builder(this, NOTIFICATION_CHANNEL_ID)
495+
.setContentTitle("Download Cancelled").setContentText(modelName)
496+
.setSmallIcon(android.R.drawable.ic_menu_close_clear_cancel).setOngoing(false)
497+
.build()
498+
}
499+
432500
error != null -> {
433501
NotificationCompat.Builder(this, NOTIFICATION_CHANNEL_ID)
434502
.setContentTitle("Download Failed").setContentText(error)
@@ -440,7 +508,7 @@ class ModelDownloadService : Service() {
440508
}
441509
}
442510

443-
notificationManager.notify(NOTIFICATION_ID, notification)
511+
notificationManager.notify(notificationId, notification)
444512
}
445513

446514
override fun onBind(intent: Intent?): IBinder? = null

0 commit comments

Comments
 (0)