Skip to content

Commit 715ab58

Browse files
authored
OpenAI provider implementation (#134)
* OpenAI provider implementation | Patch 1 * OpenAI provider implementation | Patch 2 * OpenAI provider implementation | Patch 3
1 parent 4a62ea2 commit 715ab58

File tree

53 files changed

+1016
-338
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1016
-338
lines changed

app/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ android {
2121
buildConfigField "String", "HUGGING_FACE_URL", "\"https://huggingface.co/\""
2222
buildConfigField "String", "HUGGING_FACE_INFERENCE_URL", "\"https://api-inference.huggingface.co/\""
2323
buildConfigField "String", "HORDE_AI_URL", "\"https://stablehorde.net/\""
24+
buildConfigField "String", "OPEN_AI_URL", "\"https://api.openai.com/\""
2425

2526
buildConfigField "String", "HORDE_AI_SIGN_UP_URL", "\"https://stablehorde.net/register\""
2627
buildConfigField "String", "HUGGING_FACE_INFO_URL", "\"https://huggingface.co/docs/api-inference/index\""

app/src/main/java/com/shifthackz/aisdv1/app/AiStableDiffusionClientApp.kt

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package com.shifthackz.aisdv1.app
22

3+
import android.annotation.SuppressLint
34
import android.app.Application
5+
import android.database.CursorWindow
46
import android.os.StrictMode
57
import android.os.StrictMode.VmPolicy
68
import com.shifthackz.aisdv1.app.di.featureModule
@@ -20,14 +22,32 @@ import org.koin.android.ext.koin.androidContext
2022
import org.koin.core.context.startKoin
2123
import timber.log.Timber
2224

25+
2326
class AiStableDiffusionClientApp : Application() {
2427

2528
override fun onCreate() {
2629
super.onCreate()
30+
initializeLogging()
2731
StrictMode.setVmPolicy(VmPolicy.Builder().build())
2832
Thread.currentThread().setUncaughtExceptionHandler { _, t -> errorLog(t) }
33+
initializeCursorSize()
2934
initializeKoin()
30-
initializeLogging()
35+
}
36+
37+
/**
38+
* Overrides the cursor size to prevent Room DB fail with big base64.
39+
*
40+
* Reference: https://stackoverflow.com/questions/51959944/sqliteblobtoobigexception-row-too-big-to-fit-into-cursorwindow-requiredpos-0-t
41+
*/
42+
@SuppressLint("DiscouragedPrivateApi")
43+
private fun initializeCursorSize() {
44+
try {
45+
val field = CursorWindow::class.java.getDeclaredField("sCursorWindowSize")
46+
field.isAccessible = true
47+
field.set(null, 100 * 1024 * 1024) // 100 Mb
48+
} catch (e: Exception) {
49+
errorLog(e)
50+
}
3151
}
3252

3353
private fun initializeKoin() = startKoin {

app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ val providersModule = module {
4343
override val imageCdnApiUrl: String = BuildConfig.IMAGE_CDN_URL
4444
override val huggingFaceApiUrl: String = BuildConfig.HUGGING_FACE_URL
4545
override val huggingFaceInferenceApiUrl = BuildConfig.HUGGING_FACE_INFERENCE_URL
46+
override val openAiApiUrl: String = BuildConfig.OPEN_AI_URL
4647
}
4748
}
4849

@@ -51,13 +52,21 @@ val providersModule = module {
5152
val preference = get<PreferenceManager>()
5253
when (preference.source) {
5354
ServerSource.HORDE -> {
54-
val key = preference.hordeApiKey.takeIf(String::isNotEmpty) ?: DEFAULT_HORDE_API_KEY
55+
val key =
56+
preference.hordeApiKey.takeIf(String::isNotEmpty) ?: DEFAULT_HORDE_API_KEY
5557
NetworkHeaders.API_KEY to key
5658
}
59+
5760
ServerSource.HUGGING_FACE -> {
5861
val key = "${NetworkPrefixes.BEARER} ${preference.huggingFaceApiKey}"
5962
NetworkHeaders.AUTHORIZATION to key
6063
}
64+
65+
ServerSource.OPEN_AI -> {
66+
val key = "${NetworkPrefixes.BEARER} ${preference.openAiApiKey}"
67+
NetworkHeaders.AUTHORIZATION to key
68+
}
69+
6170
else -> null
6271
}
6372
}
@@ -72,6 +81,7 @@ val providersModule = module {
7281
login = credentials.login,
7382
password = credentials.password,
7483
)
84+
7585
else -> CredentialsProvider.Data.None
7686
}
7787
}
@@ -120,7 +130,8 @@ val providersModule = module {
120130
override val providerPath: String = "${androidApplication().packageName}.fileprovider"
121131
override val imagesCacheDirPath: String = "${androidApplication().cacheDir}/images"
122132
override val logsCacheDirPath: String = "${androidApplication().cacheDir}/logs"
123-
override val localModelDirPath: String = "${androidApplication().filesDir.absolutePath}/model"
133+
override val localModelDirPath: String =
134+
"${androidApplication().filesDir.absolutePath}/model"
124135
}
125136
}
126137

data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import com.shifthackz.aisdv1.data.remote.HordeGenerationRemoteDataSource
88
import com.shifthackz.aisdv1.data.remote.HordeStatusSource
99
import com.shifthackz.aisdv1.data.remote.HuggingFaceGenerationRemoteDataSource
1010
import com.shifthackz.aisdv1.data.remote.HuggingFaceModelsRemoteDataSource
11+
import com.shifthackz.aisdv1.data.remote.OpenAiGenerationRemoteDataSource
1112
import com.shifthackz.aisdv1.data.remote.RandomImageRemoteDataSource
1213
import com.shifthackz.aisdv1.data.remote.ServerConfigurationRemoteDataSource
1314
import com.shifthackz.aisdv1.data.remote.StableDiffusionEmbeddingsRemoteDataSource
@@ -20,6 +21,7 @@ import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource
2021
import com.shifthackz.aisdv1.domain.datasource.HordeGenerationDataSource
2122
import com.shifthackz.aisdv1.domain.datasource.HuggingFaceGenerationDataSource
2223
import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource
24+
import com.shifthackz.aisdv1.domain.datasource.OpenAiGenerationDataSource
2325
import com.shifthackz.aisdv1.domain.datasource.RandomImageDataSource
2426
import com.shifthackz.aisdv1.domain.datasource.ServerConfigurationDataSource
2527
import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource
@@ -52,6 +54,7 @@ val remoteDataSourceModule = module {
5254
singleOf(::HordeStatusSource) bind HordeGenerationDataSource.StatusSource::class
5355
factoryOf(::HordeGenerationRemoteDataSource) bind HordeGenerationDataSource.Remote::class
5456
factoryOf(::HuggingFaceGenerationRemoteDataSource) bind HuggingFaceGenerationDataSource.Remote::class
57+
factoryOf(::OpenAiGenerationRemoteDataSource) bind OpenAiGenerationDataSource.Remote::class
5558
factoryOf(::StableDiffusionGenerationRemoteDataSource) bind StableDiffusionGenerationDataSource.Remote::class
5659
factoryOf(::StableDiffusionSamplersRemoteDataSource) bind StableDiffusionSamplersDataSource.Remote::class
5760
factoryOf(::StableDiffusionModelsRemoteDataSource) bind StableDiffusionModelsDataSource.Remote::class

data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import com.shifthackz.aisdv1.data.repository.HordeGenerationRepositoryImpl
88
import com.shifthackz.aisdv1.data.repository.HuggingFaceGenerationRepositoryImpl
99
import com.shifthackz.aisdv1.data.repository.HuggingFaceModelsRepositoryImpl
1010
import com.shifthackz.aisdv1.data.repository.LocalDiffusionGenerationRepositoryImpl
11+
import com.shifthackz.aisdv1.data.repository.OpenAiGenerationRepositoryImpl
1112
import com.shifthackz.aisdv1.data.repository.RandomImageRepositoryImpl
1213
import com.shifthackz.aisdv1.data.repository.ServerConfigurationRepositoryImpl
1314
import com.shifthackz.aisdv1.data.repository.StableDiffusionEmbeddingsRepositoryImpl
@@ -24,6 +25,7 @@ import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository
2425
import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository
2526
import com.shifthackz.aisdv1.domain.repository.HuggingFaceModelsRepository
2627
import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository
28+
import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository
2729
import com.shifthackz.aisdv1.domain.repository.RandomImageRepository
2830
import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository
2931
import com.shifthackz.aisdv1.domain.repository.StableDiffusionEmbeddingsRepository
@@ -51,6 +53,7 @@ val repositoryModule = module {
5153
factoryOf(::LocalDiffusionGenerationRepositoryImpl) bind LocalDiffusionGenerationRepository::class
5254
factoryOf(::HordeGenerationRepositoryImpl) bind HordeGenerationRepository::class
5355
factoryOf(::HuggingFaceGenerationRepositoryImpl) bind HuggingFaceGenerationRepository::class
56+
factoryOf(::OpenAiGenerationRepositoryImpl) bind OpenAiGenerationRepository::class
5457
factoryOf(::StableDiffusionGenerationRepositoryImpl) bind StableDiffusionGenerationRepository::class
5558
factoryOf(::StableDiffusionModelsRepositoryImpl) bind StableDiffusionModelsRepository::class
5659
factoryOf(::StableDiffusionSamplersRepositoryImpl) bind StableDiffusionSamplersRepository::class

data/src/main/java/com/shifthackz/aisdv1/data/mappers/TextToImagePayloadMappers.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package com.shifthackz.aisdv1.data.mappers
22

33
import com.shifthackz.aisdv1.domain.entity.AiGenerationResult
4+
import com.shifthackz.aisdv1.domain.entity.OpenAiModel
45
import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
56
import com.shifthackz.aisdv1.network.request.HordeGenerationAsyncRequest
67
import com.shifthackz.aisdv1.network.request.HuggingFaceGenerationRequest
8+
import com.shifthackz.aisdv1.network.request.OpenAiRequest
79
import com.shifthackz.aisdv1.network.request.TextToImageRequest
810
import com.shifthackz.aisdv1.network.response.SdGenerationResponse
911
import java.util.Date
@@ -59,6 +61,17 @@ fun TextToImagePayload.mapToHuggingFaceRequest(): HuggingFaceGenerationRequest =
5961
)
6062
}
6163

64+
fun TextToImagePayload.mapToOpenAiRequest(): OpenAiRequest = with(this) {
65+
OpenAiRequest(
66+
prompt = prompt,
67+
model = openAiModel?.alias ?: OpenAiModel.DALL_E_2.alias,
68+
size = "${width}x${height}",
69+
responseFormat = "b64_json",
70+
quality = quality,
71+
style = style,
72+
)
73+
}
74+
6275
fun Pair<TextToImagePayload, SdGenerationResponse>.mapToAiGenResult(): AiGenerationResult =
6376
let { (payload, response) ->
6477
AiGenerationResult(

data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ class PreferenceManagerImpl(
7777
.apply()
7878
.also { onPreferencesChanged() }
7979

80+
override var openAiApiKey: String
81+
get() = preferences.getString(KEY_OPEN_AI_API_KEY, "") ?: ""
82+
set(value) = preferences.edit()
83+
.putString(KEY_OPEN_AI_API_KEY, value)
84+
.apply()
85+
.also { onPreferencesChanged() }
86+
8087
override var huggingFaceApiKey: String
8188
get() = preferences.getString(KEY_HUGGING_FACE_API_KEY, "") ?: ""
8289
set(value) = preferences.edit()
@@ -142,6 +149,7 @@ class PreferenceManagerImpl(
142149
private const val KEY_FORM_ALWAYS_SHOW_ADVANCED_OPTIONS = "key_always_show_advanced_options"
143150
private const val KEY_SERVER_SOURCE = "key_server_source"
144151
private const val KEY_HORDE_API_KEY = "key_horde_api_key"
152+
private const val KEY_OPEN_AI_API_KEY = "key_open_ai_api_key"
145153
private const val KEY_HUGGING_FACE_API_KEY = "key_hugging_face_api_key"
146154
private const val KEY_HUGGING_FACE_MODEL_KEY = "key_hugging_face_model_key"
147155
private const val KEY_LOCAL_NN_API = "key_local_nn_api"

data/src/main/java/com/shifthackz/aisdv1/data/remote/HuggingFaceGenerationRemoteDataSource.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ internal class HuggingFaceGenerationRemoteDataSource(
2828

2929
override fun textToImage(
3030
modelName: String,
31-
payload: TextToImagePayload
31+
payload: TextToImagePayload,
3232
): Single<AiGenerationResult> = huggingFaceInferenceApi
3333
.generate(modelName, payload.mapToHuggingFaceRequest())
3434
.map(BitmapToBase64Converter::Input)
@@ -39,7 +39,7 @@ internal class HuggingFaceGenerationRemoteDataSource(
3939

4040
override fun imageToImage(
4141
modelName: String,
42-
payload: ImageToImagePayload
42+
payload: ImageToImagePayload,
4343
): Single<AiGenerationResult> = huggingFaceInferenceApi
4444
.generate(modelName, payload.mapToHuggingFaceRequest())
4545
.map(BitmapToBase64Converter::Input)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package com.shifthackz.aisdv1.data.remote
2+
3+
import com.shifthackz.aisdv1.core.common.log.errorLog
4+
import com.shifthackz.aisdv1.data.mappers.mapCloudToAiGenResult
5+
import com.shifthackz.aisdv1.data.mappers.mapToOpenAiRequest
6+
import com.shifthackz.aisdv1.domain.datasource.OpenAiGenerationDataSource
7+
import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
8+
import com.shifthackz.aisdv1.network.api.openai.OpenAiApi
9+
import io.reactivex.rxjava3.core.Single
10+
import java.lang.IllegalStateException
11+
12+
internal class OpenAiGenerationRemoteDataSource(
13+
private val api: OpenAiApi,
14+
) : OpenAiGenerationDataSource.Remote {
15+
16+
override fun validateApiKey() = api
17+
.validateBearerToken()
18+
.andThen(Single.just(true))
19+
.onErrorResumeNext { t ->
20+
errorLog(t)
21+
Single.just(false)
22+
}
23+
24+
override fun textToImage(payload: TextToImagePayload) = payload
25+
.mapToOpenAiRequest()
26+
.let(api::generateImage)
27+
.flatMap { response ->
28+
response.data?.firstOrNull()?.b64json?.let { base64 ->
29+
Single.just(payload to base64)
30+
} ?: run {
31+
Single.error(IllegalStateException("Got null data object from API."))
32+
}
33+
}
34+
.map(Pair<TextToImagePayload, String>::mapCloudToAiGenResult)
35+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package com.shifthackz.aisdv1.data.repository
2+
3+
import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter
4+
import com.shifthackz.aisdv1.data.core.CoreGenerationRepository
5+
import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource
6+
import com.shifthackz.aisdv1.domain.datasource.OpenAiGenerationDataSource
7+
import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
8+
import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway
9+
import com.shifthackz.aisdv1.domain.preference.PreferenceManager
10+
import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository
11+
import io.reactivex.rxjava3.core.Single
12+
13+
internal class OpenAiGenerationRepositoryImpl(
14+
mediaStoreGateway: MediaStoreGateway,
15+
base64ToBitmapConverter: Base64ToBitmapConverter,
16+
localDataSource: GenerationResultDataSource.Local,
17+
preferenceManager: PreferenceManager,
18+
private val remoteDataSource: OpenAiGenerationDataSource.Remote,
19+
) : CoreGenerationRepository(
20+
mediaStoreGateway,
21+
base64ToBitmapConverter,
22+
localDataSource,
23+
preferenceManager,
24+
), OpenAiGenerationRepository {
25+
26+
override fun validateApiKey() = remoteDataSource.validateApiKey()
27+
28+
override fun generateFromText(payload: TextToImagePayload) = remoteDataSource
29+
.textToImage(payload)
30+
.flatMap(::insertGenerationResult)
31+
}

0 commit comments

Comments
 (0)