Skip to content

Commit de60f47

Browse files
authored
Handle huggingface 503 as a retry signal (#146)
1 parent 74ce0b3 commit de60f47

File tree

4 files changed

+85
-51
lines changed

4 files changed

+85
-51
lines changed

data/src/main/java/com/shifthackz/aisdv1/data/remote/HordeGenerationRemoteDataSource.kt

Lines changed: 42 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -47,57 +47,52 @@ internal class HordeGenerationRemoteDataSource(
4747
?.let(hordeApi::cancelRequest)
4848
?: Completable.error(Throwable("No cached request id"))
4949

50-
private fun executeRequestChain(request: HordeGenerationAsyncRequest): Single<String> {
51-
val observableChain = hordeApi
52-
.generateAsync(request)
53-
.flatMapObservable { asyncStartResponse ->
54-
statusSource.id = asyncStartResponse.id
55-
asyncStartResponse.id?.let { id ->
56-
val pingObs = Observable
57-
.fromSingle(hordeApi.checkGeneration(id))
58-
.flatMap { pingResponse ->
59-
if (pingResponse.isPossible == false) {
60-
return@flatMap Observable.error(Throwable("Response is not possible"))
61-
}
62-
if (pingResponse.done == true) {
63-
return@flatMap Observable.fromSingle(hordeApi.checkStatus(id))
64-
}
65-
statusSource.update(
66-
HordeProcessStatus(
67-
waitTimeSeconds = pingResponse.waitTime ?: 0,
68-
queuePosition = pingResponse.queuePosition,
69-
)
70-
)
71-
return@flatMap Observable.error(RetryException())
50+
private fun executeRequestChain(request: HordeGenerationAsyncRequest) = hordeApi
51+
.generateAsync(request)
52+
.flatMapObservable { asyncStartResponse ->
53+
statusSource.id = asyncStartResponse.id
54+
asyncStartResponse.id?.let { id ->
55+
Observable
56+
.fromSingle(hordeApi.checkGeneration(id))
57+
.flatMap { pingResponse ->
58+
if (pingResponse.isPossible == false) {
59+
return@flatMap Observable.error(Throwable("Response is not possible"))
60+
}
61+
if (pingResponse.done == true) {
62+
return@flatMap Observable.fromSingle(hordeApi.checkStatus(id))
7263
}
73-
.retryWhen { obs ->
74-
obs.flatMap { t ->
75-
if (t is RetryException) {
76-
return@flatMap Observable
77-
.timer(HORDE_SOCKET_PING_TIME_SECONDS, TimeUnit.SECONDS)
78-
.doOnNext {
79-
debugLog("Retrying HORDE status check...")
80-
}
64+
statusSource.update(
65+
HordeProcessStatus(
66+
waitTimeSeconds = pingResponse.waitTime ?: 0,
67+
queuePosition = pingResponse.queuePosition,
68+
)
69+
)
70+
return@flatMap Observable.error(RetryException())
71+
}
72+
.retryWhen { obs ->
73+
obs.flatMap { t ->
74+
if (t is RetryException) Observable
75+
.timer(HORDE_SOCKET_PING_TIME_SECONDS, TimeUnit.SECONDS)
76+
.doOnNext {
77+
debugLog("Retrying HORDE status check...")
8178
}
82-
return@flatMap Observable.error(t)
83-
}
79+
else
80+
Observable.error(t)
8481
}
82+
}
83+
} ?: Observable.error(Throwable("Horde returned null generation id"))
84+
}
85+
.flatMapSingle {
86+
it.generations?.firstOrNull()?.let { generation ->
87+
val bytes = URL(generation.img).readBytes()
88+
val bitmap = BitmapFactory.decodeByteArray(bytes, 0, bytes.size)
89+
Single.just(bitmap)
90+
} ?: Single.error(Throwable("Error extracting image"))
91+
}
92+
.flatMapSingle { converter(BitmapToBase64Converter.Input(it)) }
93+
.map { it.base64ImageString }
94+
.let { Single.fromObservable(it) }
8595

86-
pingObs
87-
} ?: Observable.error(Throwable("Horde returned null generation id"))
88-
}
89-
.flatMapSingle {
90-
it.generations?.firstOrNull()?.let { generation ->
91-
val bytes = URL(generation.img).readBytes()
92-
val bitmap = BitmapFactory.decodeByteArray(bytes, 0, bytes.size)
93-
Single.just(bitmap)
94-
} ?: Single.error(Throwable("Error extracting image"))
95-
}
96-
.flatMapSingle { converter(BitmapToBase64Converter.Input(it)) }
97-
.map { it.base64ImageString }
98-
99-
return Single.fromObservable(observableChain)
100-
}
10196

10297
private class RetryException : Throwable()
10398

network/src/main/java/com/shifthackz/aisdv1/network/api/huggingface/HuggingFaceInferenceApi.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import android.graphics.Bitmap
44
import com.shifthackz.aisdv1.network.request.HuggingFaceGenerationRequest
55
import io.reactivex.rxjava3.core.Single
66
import okhttp3.ResponseBody
7+
import retrofit2.Response
78
import retrofit2.http.Body
89
import retrofit2.http.POST
910
import retrofit2.http.Path
@@ -22,6 +23,6 @@ interface HuggingFaceInferenceApi {
2223
fun generate(
2324
@Path("model") model: String,
2425
@Body request: HuggingFaceGenerationRequest,
25-
): Single<ResponseBody>
26+
): Single<Response<ResponseBody>>
2627
}
2728
}

network/src/main/java/com/shifthackz/aisdv1/network/api/huggingface/HuggingFaceInferenceApiImpl.kt

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ package com.shifthackz.aisdv1.network.api.huggingface
22

33
import android.graphics.Bitmap
44
import android.graphics.BitmapFactory
5+
import com.shifthackz.aisdv1.core.common.log.debugLog
56
import com.shifthackz.aisdv1.network.request.HuggingFaceGenerationRequest
7+
import io.reactivex.rxjava3.core.Observable
68
import io.reactivex.rxjava3.core.Single
9+
import java.util.concurrent.TimeUnit
710

811
internal class HuggingFaceInferenceApiImpl(
912
private val rawApi: HuggingFaceInferenceApi.RawApi,
@@ -14,8 +17,33 @@ internal class HuggingFaceInferenceApiImpl(
1417
request: HuggingFaceGenerationRequest,
1518
): Single<Bitmap> = rawApi
1619
.generate(model, request)
17-
.map { body ->
18-
val bytes = body.bytes()
19-
BitmapFactory.decodeByteArray(bytes, 0, bytes.size)
20+
.flatMapObservable { response ->
21+
if (response.isSuccessful) {
22+
response.body()
23+
?.bytes()
24+
?.let { BitmapFactory.decodeByteArray(it, 0, it.size) }
25+
?.let { Observable.just(it) }
26+
?: Observable.error(Throwable("Body is null"))
27+
} else {
28+
when (response.code()) {
29+
503 -> Observable.error(RetryException())
30+
31+
else -> {
32+
Observable.error(Throwable(response.errorBody()?.string().toString()))
33+
}
34+
}
35+
}
36+
}
37+
.retryWhen { obs ->
38+
obs.flatMap { t ->
39+
if (t is RetryException) Observable
40+
.timer(20L, TimeUnit.SECONDS)
41+
.doOnNext { debugLog("Retrying hugging face due to 503...") }
42+
else
43+
Observable.error(t)
44+
}
2045
}
46+
.let { Single.fromObservable(it) }
47+
48+
private class RetryException : Throwable()
2149
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package com.shifthackz.aisdv1.network.response
2+
3+
import com.google.gson.annotations.SerializedName
4+
5+
data class HuggingFaceErrorResponse(
6+
@SerializedName("error")
7+
val error: String?,
8+
@SerializedName("estimated_time")
9+
val estimatedTime: Double?,
10+
)

0 commit comments

Comments
 (0)