Skip to content

Commit c7d93b1

Browse files
authored
Generation cancellation (#131)
* Generation cancel implementation * Generation cancel implementation * Cancel fetch random image
1 parent 686ebce commit c7d93b1

File tree

27 files changed

+290
-147
lines changed

27 files changed

+290
-147
lines changed

data/src/main/java/com/shifthackz/aisdv1/data/remote/HordeGenerationRemoteDataSource.kt

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload
1111
import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
1212
import com.shifthackz.aisdv1.network.api.horde.HordeRestApi
1313
import com.shifthackz.aisdv1.network.request.HordeGenerationAsyncRequest
14-
import com.shifthackz.aisdv1.network.response.HordeGenerationCheckResponse
1514
import io.reactivex.rxjava3.core.BackpressureStrategy
15+
import io.reactivex.rxjava3.core.Completable
1616
import io.reactivex.rxjava3.core.Flowable
1717
import io.reactivex.rxjava3.core.Observable
1818
import io.reactivex.rxjava3.core.Single
@@ -43,10 +43,15 @@ internal class HordeGenerationRemoteDataSource(
4343
.map { base64 -> payload to base64 }
4444
.map(Pair<ImageToImagePayload, String>::mapHordeToAiGenResult)
4545

46+
override fun interruptGeneration() = statusSource.id
47+
?.let(hordeApi::cancelRequest)
48+
?: Completable.error(Throwable("No cached request id"))
49+
4650
private fun executeRequestChain(request: HordeGenerationAsyncRequest): Single<String> {
4751
val observableChain = hordeApi
4852
.generateAsync(request)
4953
.flatMapObservable { asyncStartResponse ->
54+
statusSource.id = asyncStartResponse.id
5055
asyncStartResponse.id?.let { id ->
5156
val pingObs = Observable
5257
.fromSingle(hordeApi.checkGeneration(id))
@@ -63,7 +68,7 @@ internal class HordeGenerationRemoteDataSource(
6368
queuePosition = pingResponse.queuePosition,
6469
)
6570
)
66-
return@flatMap Observable.error(RetryException(pingResponse))
71+
return@flatMap Observable.error(RetryException())
6772
}
6873
.retryWhen { obs ->
6974
obs.flatMap { t ->
@@ -94,7 +99,7 @@ internal class HordeGenerationRemoteDataSource(
9499
return Single.fromObservable(observableChain)
95100
}
96101

97-
private class RetryException(val response: HordeGenerationCheckResponse): Throwable()
102+
private class RetryException : Throwable()
98103

99104
companion object {
100105
private const val HORDE_SOCKET_PING_TIME_SECONDS = 10L
@@ -103,6 +108,11 @@ internal class HordeGenerationRemoteDataSource(
103108

104109
internal class HordeStatusSource : HordeGenerationDataSource.StatusSource {
105110
private val processStatusSubject: PublishSubject<HordeProcessStatus> = PublishSubject.create()
111+
private var _id: String? = null
112+
113+
override var id: String?
114+
get() = _id
115+
set(value) { _id = value }
106116

107117
override fun observe(): Flowable<HordeProcessStatus> = processStatusSubject
108118
.toFlowable(BackpressureStrategy.LATEST)

data/src/main/java/com/shifthackz/aisdv1/data/remote/StableDiffusionGenerationRemoteDataSource.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
1010
import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi
1111
import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi.Companion.PATH_IMG_TO_IMG
1212
import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111RestApi.Companion.PATH_TXT_TO_IMG
13-
import com.shifthackz.aisdv1.network.extensions.withExceptionMapper
1413
import com.shifthackz.aisdv1.network.response.SdGenerationResponse
1514

1615
internal class StableDiffusionGenerationRemoteDataSource(
@@ -27,11 +26,12 @@ internal class StableDiffusionGenerationRemoteDataSource(
2726
.flatMap { url -> api.textToImage(url, payload.mapToRequest()) }
2827
.map { response -> payload to response }
2928
.map(Pair<TextToImagePayload, SdGenerationResponse>::mapToAiGenResult)
30-
.withExceptionMapper()
3129

3230
override fun imageToImage(payload: ImageToImagePayload) = serverUrlProvider(PATH_IMG_TO_IMG)
3331
.flatMap { url -> api.imageToImage(url, payload.mapToRequest()) }
3432
.map { response -> payload to response }
3533
.map(Pair<ImageToImagePayload, SdGenerationResponse>::mapToAiGenResult)
36-
.withExceptionMapper()
34+
35+
override fun interruptGeneration() = serverUrlProvider(Automatic1111RestApi.PATH_INTERRUPT)
36+
.flatMapCompletable(api::interrupt)
3737
}

data/src/main/java/com/shifthackz/aisdv1/data/repository/HordeGenerationRepositoryImpl.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,6 @@ internal class HordeGenerationRepositoryImpl(
3535
override fun generateFromImage(payload: ImageToImagePayload) = remoteDataSource
3636
.imageToImage(payload)
3737
.flatMap(::insertGenerationResult)
38+
39+
override fun interruptGeneration() = remoteDataSource.interruptGeneration()
3840
}

data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ internal class LocalDiffusionGenerationRepositoryImpl(
3939
else Single.error(Throwable("Model not downloaded"))
4040
}
4141

42+
override fun interruptGeneration() = localDiffusion.interrupt()
43+
4244
private fun generate(payload: TextToImagePayload) = localDiffusion
4345
.process(payload)
4446
.subscribeOn(schedulersProvider.computation)

data/src/main/java/com/shifthackz/aisdv1/data/repository/StableDiffusionGenerationRepositoryImpl.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,6 @@ internal class StableDiffusionGenerationRepositoryImpl(
4848

4949
return chain.flatMap(::insertGenerationResult)
5050
}
51+
52+
override fun interruptGeneration() = remoteDataSource.interruptGeneration()
5153
}

domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/AppVersionDataSource.kt

Lines changed: 0 additions & 15 deletions
This file was deleted.

domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/HordeGenerationDataSource.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.entity.AiGenerationResult
44
import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus
55
import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload
66
import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
7+
import io.reactivex.rxjava3.core.Completable
78
import io.reactivex.rxjava3.core.Flowable
89
import io.reactivex.rxjava3.core.Single
910

@@ -12,9 +13,11 @@ sealed interface HordeGenerationDataSource {
1213
fun validateApiKey(): Single<Boolean>
1314
fun textToImage(payload: TextToImagePayload): Single<AiGenerationResult>
1415
fun imageToImage(payload: ImageToImagePayload): Single<AiGenerationResult>
16+
fun interruptGeneration(): Completable
1517
}
1618

1719
interface StatusSource : HordeGenerationDataSource {
20+
var id: String?
1821
fun observe(): Flowable<HordeProcessStatus>
1922
fun update(status: HordeProcessStatus)
2023
}

domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/StableDiffusionGenerationDataSource.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ sealed interface StableDiffusionGenerationDataSource {
1212
fun checkAvailability(url: String): Completable
1313
fun textToImage(payload: TextToImagePayload): Single<AiGenerationResult>
1414
fun imageToImage(payload: ImageToImagePayload): Single<AiGenerationResult>
15+
fun interruptGeneration(): Completable
1516
}
1617
}

domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,77 @@
11
package com.shifthackz.aisdv1.domain.di
22

3-
import com.shifthackz.aisdv1.domain.interactor.wakelock.*
4-
import com.shifthackz.aisdv1.domain.usecase.caching.*
5-
import com.shifthackz.aisdv1.domain.usecase.connectivity.*
6-
import com.shifthackz.aisdv1.domain.usecase.debug.*
7-
import com.shifthackz.aisdv1.domain.usecase.downloadable.*
8-
import com.shifthackz.aisdv1.domain.usecase.gallery.*
9-
import com.shifthackz.aisdv1.domain.usecase.generation.*
3+
import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor
4+
import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActorImpl
5+
import com.shifthackz.aisdv1.domain.usecase.caching.ClearAppCacheUseCase
6+
import com.shifthackz.aisdv1.domain.usecase.caching.ClearAppCacheUseCaseImpl
7+
import com.shifthackz.aisdv1.domain.usecase.caching.DataPreLoaderUseCase
8+
import com.shifthackz.aisdv1.domain.usecase.caching.DataPreLoaderUseCaseImpl
9+
import com.shifthackz.aisdv1.domain.usecase.caching.GetLastResultFromCacheUseCase
10+
import com.shifthackz.aisdv1.domain.usecase.caching.GetLastResultFromCacheUseCaseImpl
11+
import com.shifthackz.aisdv1.domain.usecase.caching.SaveLastResultToCacheUseCase
12+
import com.shifthackz.aisdv1.domain.usecase.caching.SaveLastResultToCacheUseCaseImpl
13+
import com.shifthackz.aisdv1.domain.usecase.connectivity.ObserveSeverConnectivityUseCase
14+
import com.shifthackz.aisdv1.domain.usecase.connectivity.ObserveSeverConnectivityUseCaseImpl
15+
import com.shifthackz.aisdv1.domain.usecase.connectivity.PingStableDiffusionServiceUseCase
16+
import com.shifthackz.aisdv1.domain.usecase.connectivity.PingStableDiffusionServiceUseCaseImpl
17+
import com.shifthackz.aisdv1.domain.usecase.connectivity.TestConnectivityUseCase
18+
import com.shifthackz.aisdv1.domain.usecase.connectivity.TestConnectivityUseCaseImpl
19+
import com.shifthackz.aisdv1.domain.usecase.connectivity.TestHordeApiKeyUseCase
20+
import com.shifthackz.aisdv1.domain.usecase.connectivity.TestHordeApiKeyUseCaseImpl
21+
import com.shifthackz.aisdv1.domain.usecase.debug.DebugInsertBadBase64UseCase
22+
import com.shifthackz.aisdv1.domain.usecase.debug.DebugInsertBadBase64UseCaseImpl
23+
import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase
24+
import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCaseImpl
25+
import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase
26+
import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCaseImpl
27+
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase
28+
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCaseImpl
29+
import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCase
30+
import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCaseImpl
31+
import com.shifthackz.aisdv1.domain.usecase.gallery.GetAllGalleryUseCase
32+
import com.shifthackz.aisdv1.domain.usecase.gallery.GetAllGalleryUseCaseImpl
33+
import com.shifthackz.aisdv1.domain.usecase.gallery.GetMediaStoreInfoUseCase
34+
import com.shifthackz.aisdv1.domain.usecase.gallery.GetMediaStoreInfoUseCaseImpl
35+
import com.shifthackz.aisdv1.domain.usecase.generation.GetGenerationResultPagedUseCase
36+
import com.shifthackz.aisdv1.domain.usecase.generation.GetGenerationResultPagedUseCaseImpl
37+
import com.shifthackz.aisdv1.domain.usecase.generation.GetGenerationResultUseCase
38+
import com.shifthackz.aisdv1.domain.usecase.generation.GetGenerationResultUseCaseImpl
39+
import com.shifthackz.aisdv1.domain.usecase.generation.GetRandomImageUseCase
40+
import com.shifthackz.aisdv1.domain.usecase.generation.GetRandomImageUseCaseImpl
41+
import com.shifthackz.aisdv1.domain.usecase.generation.ImageToImageUseCase
42+
import com.shifthackz.aisdv1.domain.usecase.generation.ImageToImageUseCaseImpl
43+
import com.shifthackz.aisdv1.domain.usecase.generation.InterruptGenerationUseCase
44+
import com.shifthackz.aisdv1.domain.usecase.generation.InterruptGenerationUseCaseImpl
45+
import com.shifthackz.aisdv1.domain.usecase.generation.ObserveHordeProcessStatusUseCase
46+
import com.shifthackz.aisdv1.domain.usecase.generation.ObserveHordeProcessStatusUseCaseImpl
47+
import com.shifthackz.aisdv1.domain.usecase.generation.ObserveLocalDiffusionProcessStatusUseCase
48+
import com.shifthackz.aisdv1.domain.usecase.generation.ObserveLocalDiffusionProcessStatusUseCaseImpl
49+
import com.shifthackz.aisdv1.domain.usecase.generation.SaveGenerationResultUseCase
50+
import com.shifthackz.aisdv1.domain.usecase.generation.SaveGenerationResultUseCaseImpl
51+
import com.shifthackz.aisdv1.domain.usecase.generation.TextToImageUseCase
52+
import com.shifthackz.aisdv1.domain.usecase.generation.TextToImageUseCaseImpl
1053
import com.shifthackz.aisdv1.domain.usecase.sdembedding.FetchAndGetEmbeddingsUseCase
1154
import com.shifthackz.aisdv1.domain.usecase.sdembedding.FetchAndGetEmbeddingsUseCaseImpl
1255
import com.shifthackz.aisdv1.domain.usecase.sdhypernet.FetchAndGetHyperNetworksUseCase
1356
import com.shifthackz.aisdv1.domain.usecase.sdhypernet.FetchAndGetHyperNetworksUseCaseImpl
14-
import com.shifthackz.aisdv1.domain.usecase.sdlora.*
15-
import com.shifthackz.aisdv1.domain.usecase.sdmodel.*
16-
import com.shifthackz.aisdv1.domain.usecase.sdsampler.*
17-
import com.shifthackz.aisdv1.domain.usecase.settings.*
18-
import com.shifthackz.aisdv1.domain.usecase.splash.*
19-
import com.shifthackz.aisdv1.domain.usecase.wakelock.*
57+
import com.shifthackz.aisdv1.domain.usecase.sdlora.FetchAndGetLorasUseCase
58+
import com.shifthackz.aisdv1.domain.usecase.sdlora.FetchAndGetLorasUseCaseImpl
59+
import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseCase
60+
import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseCaseImpl
61+
import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase
62+
import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCaseImpl
63+
import com.shifthackz.aisdv1.domain.usecase.sdsampler.GetStableDiffusionSamplersUseCase
64+
import com.shifthackz.aisdv1.domain.usecase.sdsampler.GetStableDiffusionSamplersUseCaseImpl
65+
import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase
66+
import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCaseImpl
67+
import com.shifthackz.aisdv1.domain.usecase.settings.SetServerConfigurationUseCase
68+
import com.shifthackz.aisdv1.domain.usecase.settings.SetServerConfigurationUseCaseImpl
69+
import com.shifthackz.aisdv1.domain.usecase.splash.SplashNavigationUseCase
70+
import com.shifthackz.aisdv1.domain.usecase.splash.SplashNavigationUseCaseImpl
71+
import com.shifthackz.aisdv1.domain.usecase.wakelock.AcquireWakelockUseCase
72+
import com.shifthackz.aisdv1.domain.usecase.wakelock.AcquireWakelockUseCaseImpl
73+
import com.shifthackz.aisdv1.domain.usecase.wakelock.ReleaseWakeLockUseCase
74+
import com.shifthackz.aisdv1.domain.usecase.wakelock.ReleaseWakeLockUseCaseImpl
2075
import org.koin.core.module.dsl.factoryOf
2176
import org.koin.dsl.bind
2277
import org.koin.dsl.module
@@ -55,6 +110,7 @@ internal val useCasesModule = module {
55110
factoryOf(::DeleteModelUseCaseImpl) bind DeleteModelUseCase::class
56111
factoryOf(::AcquireWakelockUseCaseImpl) bind AcquireWakelockUseCase::class
57112
factoryOf(::ReleaseWakeLockUseCaseImpl) bind ReleaseWakeLockUseCase::class
113+
factoryOf(::InterruptGenerationUseCaseImpl) bind InterruptGenerationUseCase::class
58114
}
59115

60116
internal val interActorsModule = module {

domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
package com.shifthackz.aisdv1.domain.feature.diffusion
22

33
import android.graphics.Bitmap
4-
import com.shifthackz.aisdv1.domain.entity.AiGenerationResult
54
import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
65
import io.reactivex.rxjava3.core.Completable
7-
import io.reactivex.rxjava3.core.Flowable
86
import io.reactivex.rxjava3.core.Observable
97
import io.reactivex.rxjava3.core.Single
108

119
interface LocalDiffusion {
1210
fun process(payload: TextToImagePayload): Single<Bitmap>
11+
fun interrupt(): Completable
1312
fun observeStatus(): Observable<Status>
1413

1514
data class Status(val current: Int, val total: Int)

0 commit comments

Comments
 (0)