Skip to content

Commit 4a62ea2

Browse files
authored
HuggingFace provider implementation (#133)
* HuggingFace provider implementation | Patch 1 * HuggingFace provider implementation | Patch 2
1 parent fb9be4e commit 4a62ea2

File tree

91 files changed

+1493
-429
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+1493
-429
lines changed

app/build.gradle

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@ android {
1717
versionName "0.5.7"
1818
versionCode 171
1919

20-
buildConfigField "String", "IMAGE_CDN_URL", "\"https://random.imagecdn.app\""
21-
buildConfigField "String", "HORDE_AI_URL", "\"https://stablehorde.net\""
20+
buildConfigField "String", "IMAGE_CDN_URL", "\"https://random.imagecdn.app/\""
21+
buildConfigField "String", "HUGGING_FACE_URL", "\"https://huggingface.co/\""
22+
buildConfigField "String", "HUGGING_FACE_INFERENCE_URL", "\"https://api-inference.huggingface.co/\""
23+
buildConfigField "String", "HORDE_AI_URL", "\"https://stablehorde.net/\""
24+
2225
buildConfigField "String", "HORDE_AI_SIGN_UP_URL", "\"https://stablehorde.net/register\""
26+
buildConfigField "String", "HUGGING_FACE_INFO_URL", "\"https://huggingface.co/docs/api-inference/index\""
2327
buildConfigField "String", "UPDATE_API_URL", "\"https://sdai.moroz.cc\""
2428
buildConfigField "String", "DEMO_MODE_API_URL", "\"https://sdai.moroz.cc\""
2529
buildConfigField "String", "POLICY_URL", "\"https://sdai.moroz.cc/policy.html\""

app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import com.shifthackz.aisdv1.core.common.appbuild.BuildVersion
77
import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor
88
import com.shifthackz.aisdv1.core.common.links.LinksProvider
99
import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider
10+
import com.shifthackz.aisdv1.domain.entity.ServerSource
1011
import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials
1112
import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationStore
1213
import com.shifthackz.aisdv1.domain.preference.PreferenceManager
@@ -15,7 +16,9 @@ import com.shifthackz.aisdv1.feature.diffusion.environment.DeviceNNAPIFlagProvid
1516
import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider
1617
import com.shifthackz.aisdv1.network.qualifiers.ApiUrlProvider
1718
import com.shifthackz.aisdv1.network.qualifiers.CredentialsProvider
18-
import com.shifthackz.aisdv1.network.qualifiers.HordeApiKeyProvider
19+
import com.shifthackz.aisdv1.network.qualifiers.ApiKeyProvider
20+
import com.shifthackz.aisdv1.network.qualifiers.NetworkHeaders
21+
import com.shifthackz.aisdv1.network.qualifiers.NetworkPrefixes
1922
import io.reactivex.rxjava3.android.schedulers.AndroidSchedulers
2023
import io.reactivex.rxjava3.core.Scheduler
2124
import io.reactivex.rxjava3.schedulers.Schedulers
@@ -28,6 +31,7 @@ import java.util.concurrent.Executors
2831
* Needed for retrofit builder, because it will crash at runtime if baseUrl is not set
2932
*/
3033
private const val DEFAULT_SERVER_URL = "http://127.0.0.1"
34+
private const val DEFAULT_HORDE_API_KEY = "0000000000"
3135

3236
val providersModule = module {
3337

@@ -37,11 +41,26 @@ val providersModule = module {
3741
override val stableDiffusionAppApiUrl: String = BuildConfig.UPDATE_API_URL
3842
override val hordeApiUrl: String = BuildConfig.HORDE_AI_URL
3943
override val imageCdnApiUrl: String = BuildConfig.IMAGE_CDN_URL
44+
override val huggingFaceApiUrl: String = BuildConfig.HUGGING_FACE_URL
45+
override val huggingFaceInferenceApiUrl = BuildConfig.HUGGING_FACE_INFERENCE_URL
4046
}
4147
}
4248

4349
single {
44-
HordeApiKeyProvider { get<PreferenceManager>().hordeApiKey }
50+
ApiKeyProvider {
51+
val preference = get<PreferenceManager>()
52+
when (preference.source) {
53+
ServerSource.HORDE -> {
54+
val key = preference.hordeApiKey.takeIf(String::isNotEmpty) ?: DEFAULT_HORDE_API_KEY
55+
NetworkHeaders.API_KEY to key
56+
}
57+
ServerSource.HUGGING_FACE -> {
58+
val key = "${NetworkPrefixes.BEARER} ${preference.huggingFaceApiKey}"
59+
NetworkHeaders.AUTHORIZATION to key
60+
}
61+
else -> null
62+
}
63+
}
4564
}
4665

4766
single<CredentialsProvider> {
@@ -63,6 +82,7 @@ val providersModule = module {
6382
object : LinksProvider {
6483
override val hordeUrl: String = BuildConfig.HORDE_AI_URL
6584
override val hordeSignUpUrl: String = BuildConfig.HORDE_AI_SIGN_UP_URL
85+
override val huggingFaceUrl: String = BuildConfig.HUGGING_FACE_INFO_URL
6686
override val privacyPolicyUrl: String = BuildConfig.POLICY_URL
6787
override val gitHubSourceUrl: String = BuildConfig.GITHUB_SOURCE_URL
6888
override val setupInstructionsUrl: String = BuildConfig.SETUP_INSTRUCTIONS_URL
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package com.shifthackz.aisdv1.core.common.extensions
2+
3+
inline fun <T> T.applyIf(predicate: Boolean, block: T.() -> Unit): T {
4+
if (!predicate) return this
5+
return apply(block)
6+
}

core/common/src/main/java/com/shifthackz/aisdv1/core/common/links/LinksProvider.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.core.common.links
33
interface LinksProvider {
44
val hordeUrl: String
55
val hordeSignUpUrl: String
6+
val huggingFaceUrl: String
67
val privacyPolicyUrl: String
78
val gitHubSourceUrl: String
89
val setupInstructionsUrl: String

core/ui/src/main/java/com/shifthackz/aisdv1/core/contract/RxDisposableContract.kt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ interface RxDisposableContract {
99

1010
infix operator fun CompositeDisposable.plus(d: Disposable) = this.add(compositeDisposable)
1111

12-
operator fun Disposable.not() {
12+
operator fun Disposable.not(): Disposable {
1313
compositeDisposable.add(this)
14+
return this
1415
}
1516

16-
fun Disposable.addToDisposable() {
17+
fun Disposable.addToDisposable(): Disposable {
1718
compositeDisposable.add(this)
19+
return this
1820
}
1921
}

data/src/main/java/com/shifthackz/aisdv1/data/di/LocalDataSourceModule.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.data.gateway.DatabaseClearGatewayImpl
44
import com.shifthackz.aisdv1.data.gateway.mediastore.MediaStoreGatewayFactory
55
import com.shifthackz.aisdv1.data.local.DownloadableModelLocalDataSource
66
import com.shifthackz.aisdv1.data.local.GenerationResultLocalDataSource
7+
import com.shifthackz.aisdv1.data.local.HuggingFaceModelsLocalDataSource
78
import com.shifthackz.aisdv1.data.local.ServerConfigurationLocalDataSource
89
import com.shifthackz.aisdv1.data.local.StableDiffusionEmbeddingsLocalDataSource
910
import com.shifthackz.aisdv1.data.local.StableDiffusionHyperNetworksLocalDataSource
@@ -12,6 +13,7 @@ import com.shifthackz.aisdv1.data.local.StableDiffusionModelsLocalDataSource
1213
import com.shifthackz.aisdv1.data.local.StableDiffusionSamplersLocalDataSource
1314
import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource
1415
import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource
16+
import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource
1517
import com.shifthackz.aisdv1.domain.datasource.ServerConfigurationDataSource
1618
import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource
1719
import com.shifthackz.aisdv1.domain.datasource.StableDiffusionHyperNetworksDataSource
@@ -35,5 +37,6 @@ val localDataSourceModule = module {
3537
factoryOf(::ServerConfigurationLocalDataSource) bind ServerConfigurationDataSource.Local::class
3638
factoryOf(::GenerationResultLocalDataSource) bind GenerationResultDataSource.Local::class
3739
factoryOf(::DownloadableModelLocalDataSource) bind DownloadableModelDataSource.Local::class
40+
factoryOf(::HuggingFaceModelsLocalDataSource) bind HuggingFaceModelsDataSource.Local::class
3841
factory { MediaStoreGatewayFactory(androidContext(), get()).invoke() }
3942
}

data/src/main/java/com/shifthackz/aisdv1/data/di/RemoteDataSourceModule.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import com.shifthackz.aisdv1.data.provider.ServerUrlProvider
66
import com.shifthackz.aisdv1.data.remote.DownloadableModelRemoteDataSource
77
import com.shifthackz.aisdv1.data.remote.HordeGenerationRemoteDataSource
88
import com.shifthackz.aisdv1.data.remote.HordeStatusSource
9+
import com.shifthackz.aisdv1.data.remote.HuggingFaceGenerationRemoteDataSource
10+
import com.shifthackz.aisdv1.data.remote.HuggingFaceModelsRemoteDataSource
911
import com.shifthackz.aisdv1.data.remote.RandomImageRemoteDataSource
1012
import com.shifthackz.aisdv1.data.remote.ServerConfigurationRemoteDataSource
1113
import com.shifthackz.aisdv1.data.remote.StableDiffusionEmbeddingsRemoteDataSource
@@ -16,6 +18,8 @@ import com.shifthackz.aisdv1.data.remote.StableDiffusionModelsRemoteDataSource
1618
import com.shifthackz.aisdv1.data.remote.StableDiffusionSamplersRemoteDataSource
1719
import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource
1820
import com.shifthackz.aisdv1.domain.datasource.HordeGenerationDataSource
21+
import com.shifthackz.aisdv1.domain.datasource.HuggingFaceGenerationDataSource
22+
import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource
1923
import com.shifthackz.aisdv1.domain.datasource.RandomImageDataSource
2024
import com.shifthackz.aisdv1.domain.datasource.ServerConfigurationDataSource
2125
import com.shifthackz.aisdv1.domain.datasource.StableDiffusionEmbeddingsDataSource
@@ -47,6 +51,7 @@ val remoteDataSourceModule = module {
4751
}
4852
singleOf(::HordeStatusSource) bind HordeGenerationDataSource.StatusSource::class
4953
factoryOf(::HordeGenerationRemoteDataSource) bind HordeGenerationDataSource.Remote::class
54+
factoryOf(::HuggingFaceGenerationRemoteDataSource) bind HuggingFaceGenerationDataSource.Remote::class
5055
factoryOf(::StableDiffusionGenerationRemoteDataSource) bind StableDiffusionGenerationDataSource.Remote::class
5156
factoryOf(::StableDiffusionSamplersRemoteDataSource) bind StableDiffusionSamplersDataSource.Remote::class
5257
factoryOf(::StableDiffusionModelsRemoteDataSource) bind StableDiffusionModelsDataSource.Remote::class
@@ -56,11 +61,12 @@ val remoteDataSourceModule = module {
5661
factoryOf(::ServerConfigurationRemoteDataSource) bind ServerConfigurationDataSource.Remote::class
5762
factoryOf(::RandomImageRemoteDataSource) bind RandomImageDataSource.Remote::class
5863
factoryOf(::DownloadableModelRemoteDataSource) bind DownloadableModelDataSource.Remote::class
64+
factoryOf(::HuggingFaceModelsRemoteDataSource) bind HuggingFaceModelsDataSource.Remote::class
5965

6066
factory<ServerConnectivityGateway> {
6167
val lambda: () -> Boolean = {
6268
val prefs = get<PreferenceManager>()
63-
prefs.source == ServerSource.CUSTOM
69+
prefs.source == ServerSource.AUTOMATIC1111
6470
}
6571
val monitor = get<ConnectivityMonitor> { parametersOf(lambda) }
6672
ServerConnectivityGatewayImpl(monitor, get())

data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import android.os.PowerManager
55
import com.shifthackz.aisdv1.data.repository.DownloadableModelRepositoryImpl
66
import com.shifthackz.aisdv1.data.repository.GenerationResultRepositoryImpl
77
import com.shifthackz.aisdv1.data.repository.HordeGenerationRepositoryImpl
8+
import com.shifthackz.aisdv1.data.repository.HuggingFaceGenerationRepositoryImpl
9+
import com.shifthackz.aisdv1.data.repository.HuggingFaceModelsRepositoryImpl
810
import com.shifthackz.aisdv1.data.repository.LocalDiffusionGenerationRepositoryImpl
911
import com.shifthackz.aisdv1.data.repository.RandomImageRepositoryImpl
1012
import com.shifthackz.aisdv1.data.repository.ServerConfigurationRepositoryImpl
@@ -19,6 +21,8 @@ import com.shifthackz.aisdv1.data.repository.WakeLockRepositoryImpl
1921
import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository
2022
import com.shifthackz.aisdv1.domain.repository.GenerationResultRepository
2123
import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository
24+
import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository
25+
import com.shifthackz.aisdv1.domain.repository.HuggingFaceModelsRepository
2226
import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository
2327
import com.shifthackz.aisdv1.domain.repository.RandomImageRepository
2428
import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository
@@ -46,6 +50,7 @@ val repositoryModule = module {
4650
singleOf(::TemporaryGenerationResultRepositoryImpl) bind TemporaryGenerationResultRepository::class
4751
factoryOf(::LocalDiffusionGenerationRepositoryImpl) bind LocalDiffusionGenerationRepository::class
4852
factoryOf(::HordeGenerationRepositoryImpl) bind HordeGenerationRepository::class
53+
factoryOf(::HuggingFaceGenerationRepositoryImpl) bind HuggingFaceGenerationRepository::class
4954
factoryOf(::StableDiffusionGenerationRepositoryImpl) bind StableDiffusionGenerationRepository::class
5055
factoryOf(::StableDiffusionModelsRepositoryImpl) bind StableDiffusionModelsRepository::class
5156
factoryOf(::StableDiffusionSamplersRepositoryImpl) bind StableDiffusionSamplersRepository::class
@@ -56,4 +61,5 @@ val repositoryModule = module {
5661
factoryOf(::GenerationResultRepositoryImpl) bind GenerationResultRepository::class
5762
factoryOf(::RandomImageRepositoryImpl) bind RandomImageRepository::class
5863
factoryOf(::DownloadableModelRepositoryImpl) bind DownloadableModelRepository::class
64+
factoryOf(::HuggingFaceModelsRepositoryImpl) bind HuggingFaceModelsRepository::class
5965
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package com.shifthackz.aisdv1.data.local
2+
3+
import com.shifthackz.aisdv1.data.mappers.mapDomainToEntity
4+
import com.shifthackz.aisdv1.data.mappers.mapEntityToDomain
5+
import com.shifthackz.aisdv1.domain.datasource.HuggingFaceModelsDataSource
6+
import com.shifthackz.aisdv1.domain.entity.HuggingFaceModel
7+
import com.shifthackz.aisdv1.storage.db.persistent.dao.HuggingFaceModelDao
8+
import com.shifthackz.aisdv1.storage.db.persistent.entity.HuggingFaceModelEntity
9+
10+
internal class HuggingFaceModelsLocalDataSource(
11+
private val dao: HuggingFaceModelDao,
12+
) : HuggingFaceModelsDataSource.Local {
13+
14+
override fun getAll() = dao
15+
.query()
16+
.map(List<HuggingFaceModelEntity>::mapEntityToDomain)
17+
18+
override fun save(models: List<HuggingFaceModel>) = dao
19+
.deleteAll()
20+
.andThen(dao.insertList(models.mapDomainToEntity()))
21+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package com.shifthackz.aisdv1.data.mappers
2+
3+
import com.shifthackz.aisdv1.domain.entity.HuggingFaceModel
4+
import com.shifthackz.aisdv1.network.model.HuggingFaceModelRaw
5+
import com.shifthackz.aisdv1.storage.db.persistent.entity.HuggingFaceModelEntity
6+
7+
//region RAW --> DOMAIN
8+
fun List<HuggingFaceModelRaw>.mapRawToDomain(): List<HuggingFaceModel> =
9+
map(HuggingFaceModelRaw::mapRawToDomain)
10+
11+
fun HuggingFaceModelRaw.mapRawToDomain(): HuggingFaceModel = with(this) {
12+
HuggingFaceModel(
13+
id = id ?: "",
14+
name = name ?: "",
15+
alias = alias ?: "",
16+
source = source ?: ""
17+
)
18+
}
19+
//endregion
20+
21+
//region DOMAIN -> ENTITY
22+
fun List<HuggingFaceModel>.mapDomainToEntity(): List<HuggingFaceModelEntity> =
23+
map(HuggingFaceModel::mapDomainToEntity)
24+
25+
fun HuggingFaceModel.mapDomainToEntity(): HuggingFaceModelEntity = with(this) {
26+
HuggingFaceModelEntity(id, name, alias, source)
27+
}
28+
//endregion
29+
30+
//region ENTITY -> DOMAIN
31+
fun List<HuggingFaceModelEntity>.mapEntityToDomain(): List<HuggingFaceModel> =
32+
map(HuggingFaceModelEntity::mapEntityToDomain)
33+
34+
fun HuggingFaceModelEntity.mapEntityToDomain(): HuggingFaceModel = with(this) {
35+
HuggingFaceModel(id, name, alias, source)
36+
}
37+
//endregion

0 commit comments

Comments
 (0)