Skip to content

Commit cd48a13

Browse files
committed
add custom encoder for jvm
1 parent 9aab7b8 commit cd48a13

File tree

3 files changed

+210
-5
lines changed

3 files changed

+210
-5
lines changed

firebase-common/src/androidMain/kotlin/dev/teamhub/firebase/decoders.kt

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,13 @@ package dev.teamhub.firebase
33
import kotlinx.serialization.*
44
import kotlinx.serialization.CompositeDecoder.Companion.READ_ALL
55
import kotlinx.serialization.CompositeDecoder.Companion.READ_DONE
6+
import kotlinx.serialization.internal.UnitDescriptor
67
import kotlinx.serialization.internal.nullable
78
import kotlinx.serialization.modules.EmptyModule
89
import kotlinx.serialization.modules.SerialModule
910
import kotlinx.serialization.modules.getContextualOrDefault
1011
import kotlin.reflect.KClass
1112

12-
@Suppress("UNCHECKED_CAST")
13-
inline fun <reified T> encode(strategy: SerializationStrategy<T>? = null /*value?.let { EmptyModule.getContextualOrDefault(it::class as KClass<*>) } as SerializationStrategy<T>*/, value: T) =
14-
value as Any?
15-
1613
@Suppress("UNCHECKED_CAST")
1714
inline fun <reified T> decode(strategy: DeserializationStrategy<T> = EmptyModule.getContextualOrDefault(T::class as KClass<Any>).run { if(null is T) nullable else this } as DeserializationStrategy<T>, value: Any?): T {
1815
require(value != null || strategy.descriptor.isNullable) { "Value was null for non-nullable type ${T::class}" }
@@ -75,7 +72,7 @@ class FirebaseClassDecoder(private val map: Map<String, Any?>) : FirebaseComposi
7572
?: READ_DONE
7673
}
7774

78-
open class FirebaseCompositeDecoder(
75+
open class FirebaseCompositeDecoder protected constructor(
7976
private val size: Int,
8077
private val get: (desc: SerialDescriptor, index: Int) -> Any?
8178
): CompositeDecoder {
@@ -95,6 +92,10 @@ open class FirebaseCompositeDecoder(
9592
override fun <T : Any> decodeNullableSerializableElement(desc: SerialDescriptor, index: Int, deserializer: DeserializationStrategy<T?>): T? =
9693
if(decodeNotNullMark(get(desc, index))) decodeSerializableElement(desc, index, deserializer) else decodeNull(get(desc, index))
9794

95+
fun decodeNullableSerializableElement(index: Int): Any? = get(UnitDescriptor, index)?.let { value ->
96+
value.firebaseSerializer().let { decodeSerializableElement(it.descriptor, index, it) }
97+
}
98+
9899
override fun <T> updateSerializableElement(desc: SerialDescriptor, index: Int, deserializer: DeserializationStrategy<T>, old: T): T =
99100
throw UpdateNotSupportedException(deserializer.descriptor.name)
100101

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package dev.teamhub.firebase
2+
3+
import kotlinx.serialization.*
4+
import kotlinx.serialization.modules.EmptyModule
5+
6+
fun <T> encode(strategy: SerializationStrategy<T> , value: T): Any? = FirebaseEncoder().apply { encode(strategy, value) }.value
7+
8+
@ImplicitReflectionSerializer
9+
fun encode(value: Any?): Any? = value?.let {
10+
FirebaseEncoder().apply { encode(it.firebaseSerializer(), it) }.value
11+
}
12+
13+
class FirebaseEncoder : Encoder {
14+
15+
var value: Any? = null
16+
17+
override val context = EmptyModule
18+
19+
@Suppress("UNCHECKED_CAST")
20+
override fun beginStructure(desc: SerialDescriptor, vararg typeParams: KSerializer<*>) = when(desc.kind as StructureKind) {
21+
StructureKind.LIST -> MutableList<Any?>(desc.elementsCount) { null }
22+
.also { value = it }
23+
.let { FirebaseCompositeEncoder { _, index, value -> it[index] = value } }
24+
StructureKind.MAP, StructureKind.CLASS -> mutableMapOf<Any?, Any?>()
25+
.also { value = it }
26+
.let { FirebaseCompositeEncoder { _, index, value -> it[desc.getElementName(index)] = value } }
27+
}
28+
29+
override fun encodeBoolean(value: Boolean) {
30+
this.value = value
31+
}
32+
33+
override fun encodeByte(value: Byte) {
34+
this.value = value
35+
}
36+
37+
override fun encodeChar(value: Char) {
38+
this.value = value
39+
}
40+
41+
override fun encodeDouble(value: Double) {
42+
this.value = value
43+
}
44+
45+
override fun encodeEnum(enumDescription: SerialDescriptor, ordinal: Int) {
46+
this.value = enumDescription.getElementName(ordinal)
47+
}
48+
49+
override fun encodeFloat(value: Float) {
50+
this.value = value
51+
}
52+
53+
override fun encodeInt(value: Int) {
54+
this.value = value
55+
}
56+
57+
override fun encodeLong(value: Long) {
58+
this.value = value
59+
}
60+
61+
override fun encodeNotNullMark() {
62+
//no-op
63+
}
64+
65+
override fun encodeNull() {
66+
this.value = null
67+
}
68+
69+
override fun encodeShort(value: Short) {
70+
this.value = value
71+
}
72+
73+
override fun encodeString(value: String) {
74+
this.value = value
75+
}
76+
77+
override fun encodeUnit() {
78+
this.value = Unit
79+
}
80+
81+
}
82+
83+
open class FirebaseCompositeEncoder(
84+
private val set: (desc: SerialDescriptor, index: Int, value: Any?) -> Unit
85+
): CompositeEncoder {
86+
87+
override val context = EmptyModule
88+
89+
override fun <T : Any> encodeNullableSerializableElement(desc: SerialDescriptor, index: Int, serializer: SerializationStrategy<T>, value: T?) =
90+
set(desc, index, value?.let { FirebaseEncoder().apply { encode(serializer, value) }.value })
91+
92+
override fun <T> encodeSerializableElement(desc: SerialDescriptor, index: Int, serializer: SerializationStrategy<T>, value: T) =
93+
set(desc, index, FirebaseEncoder().apply { encode(serializer, value) }.value)
94+
95+
override fun encodeNonSerializableElement(desc: SerialDescriptor, index: Int, value: Any) = set(desc, index, value)
96+
97+
override fun encodeBooleanElement(desc: SerialDescriptor, index: Int, value: Boolean) = set(desc, index, value)
98+
99+
override fun encodeByteElement(desc: SerialDescriptor, index: Int, value: Byte) = set(desc, index, value)
100+
101+
override fun encodeCharElement(desc: SerialDescriptor, index: Int, value: Char) = set(desc, index, value)
102+
103+
override fun encodeDoubleElement(desc: SerialDescriptor, index: Int, value: Double) = set(desc, index, value)
104+
105+
override fun encodeFloatElement(desc: SerialDescriptor, index: Int, value: Float) = set(desc, index, value)
106+
107+
override fun encodeIntElement(desc: SerialDescriptor, index: Int, value: Int) = set(desc, index, value)
108+
109+
override fun encodeLongElement(desc: SerialDescriptor, index: Int, value: Long) = set(desc, index, value)
110+
111+
override fun encodeShortElement(desc: SerialDescriptor, index: Int, value: Short) = set(desc, index, value)
112+
113+
override fun encodeStringElement(desc: SerialDescriptor, index: Int, value: String) = set(desc, index, value)
114+
115+
override fun encodeUnitElement(desc: SerialDescriptor, index: Int) = set(desc, index, Unit)
116+
117+
}
118+
119+
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package dev.teamhub.firebase
2+
3+
import kotlinx.serialization.*
4+
import kotlinx.serialization.internal.UnitSerializer
5+
import kotlinx.serialization.internal.defaultSerializer
6+
import kotlinx.serialization.internal.nullable
7+
8+
@Suppress("UNCHECKED_CAST")
9+
fun Any.firebaseSerializer() = (this::class.compiledSerializer() ?: this::class.defaultSerializer() ?: when(this) {
10+
is Map<*, *> -> FirebaseMapSerializer()
11+
is List<*> -> FirebaseListSerializer()
12+
else -> throw SerializationException("Can't locate argument-less serializer for $this. For generic classes, such as lists, please provide serializer explicitly.")
13+
}) as KSerializer<Any>
14+
15+
class FirebaseMapSerializer : KSerializer<Map<String, Any?>> {
16+
17+
lateinit var keys: List<String>
18+
lateinit var map: Map<String, Any?>
19+
20+
override val descriptor= object : SerialDescriptor {
21+
override val kind = StructureKind.MAP
22+
override val name = "kotlin.Map<String, Any>"
23+
override fun getElementIndex(name: String) = keys.indexOf(name)
24+
override fun getElementName(index: Int) = keys[index]
25+
override val elementsCount get() = map.size
26+
}
27+
28+
override fun serialize(encoder: Encoder, obj: Map<String, Any?>) {
29+
map = obj
30+
keys = obj.keys.toList()
31+
val collectionEncoder = encoder.beginCollection(descriptor, obj.size)
32+
keys.forEachIndexed { index, key ->
33+
val value = map.getValue(key)
34+
val serializer = value?.firebaseSerializer() ?: UnitSerializer.nullable as KSerializer<Any>
35+
collectionEncoder.encodeNullableSerializableElement(
36+
serializer.descriptor, index, serializer, value
37+
)
38+
}
39+
}
40+
41+
override fun deserialize(decoder: Decoder): Map<String, Any?> {
42+
val collectionDecoder = decoder.beginStructure(descriptor) as FirebaseCompositeDecoder
43+
val map = mutableMapOf<String, Any?>()
44+
for(index in 0 until collectionDecoder.decodeCollectionSize(descriptor) * 2 step 2) {
45+
map[collectionDecoder.decodeNullableSerializableElement(index) as String] =
46+
collectionDecoder.decodeNullableSerializableElement(index + 1)
47+
}
48+
return map
49+
}
50+
}
51+
52+
class FirebaseListSerializer : KSerializer<List<Any?>> {
53+
54+
lateinit var list: List<Any?>
55+
56+
override val descriptor= object : SerialDescriptor {
57+
override val kind = StructureKind.LIST
58+
override val name = "kotlin.List<Any>"
59+
override fun getElementIndex(name: String) = throw NotImplementedError()
60+
override fun getElementName(index: Int) = throw NotImplementedError()
61+
override val elementsCount get() = list.size
62+
}
63+
64+
@Suppress("UNCHECKED_CAST")
65+
override fun serialize(encoder: Encoder, obj: List<Any?>) {
66+
list = obj
67+
val collectionEncoder = encoder.beginCollection(descriptor, obj.size)
68+
list.forEachIndexed { index, value ->
69+
val serializer = value?.let { it::class.firebaseSerializer() } ?: UnitSerializer.nullable as KSerializer<Any>
70+
collectionEncoder.encodeNullableSerializableElement(
71+
serializer.descriptor, index, serializer, value
72+
)
73+
}
74+
}
75+
76+
override fun deserialize(decoder: Decoder): List<Any?> {
77+
val collectionDecoder = decoder.beginStructure(descriptor) as FirebaseCompositeDecoder
78+
val list = MutableList<Any?>(collectionDecoder.decodeCollectionSize(descriptor)) { null }
79+
list.forEachIndexed { index, _ ->
80+
list[index] = collectionDecoder.decodeNullableSerializableElement(index)
81+
}
82+
return list
83+
}
84+
}
85+

0 commit comments

Comments
 (0)