Skip to content

Commit 93a06df

Browse files
authored
Do not use tree-based decoding for fast-path polymorphism (#1919)
Do not use tree-based decoding for fast-path polymorphism: optimistically try to read the discriminator as the very first key and then silently skip it. Fixes #1839
1 parent bb18d62 commit 93a06df

File tree

15 files changed

+233
-28
lines changed

15 files changed

+233
-28
lines changed

benchmark/build.gradle

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@ apply plugin: 'java'
66
apply plugin: 'kotlin'
77
apply plugin: 'kotlinx-serialization'
88
apply plugin: 'idea'
9-
apply plugin: 'net.ltgt.apt'
109
apply plugin: 'com.github.johnrengelman.shadow'
11-
apply plugin: 'me.champeau.gradle.jmh'
10+
apply plugin: 'me.champeau.jmh'
1211

1312
sourceCompatibility = 1.8
1413
targetCompatibility = 1.8
15-
jmh.jmhVersion = 1.22
14+
jmh.jmhVersion = "1.22"
1615

1716
jmhJar {
1817
baseName 'benchmarks'
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package kotlinx.benchmarks.json
2+
3+
import kotlinx.serialization.*
4+
import kotlinx.serialization.json.*
5+
import kotlinx.serialization.modules.*
6+
import org.openjdk.jmh.annotations.*
7+
import java.util.concurrent.*
8+
9+
/**
 * JMH benchmark comparing decoding of a concrete class against decoding of the
 * same payload through its polymorphic base type.
 *
 * Scores are reported as throughput with a millisecond output unit, as configured
 * by the annotations below.
 */
@Warmup(iterations = 7, time = 1)
@Measurement(iterations = 5, time = 1)
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Fork(1)
open class PolymorphismOverheadBenchmark {

    @Serializable
    @JsonClassDiscriminator("poly")
    data class PolymorphicWrapper(
        val i: @Polymorphic Poly,
        val i2: Impl
    ) // amortize the cost a bit

    @Serializable
    data class BaseWrapper(
        val i: Impl,
        val i2: Impl
    )

    @JsonClassDiscriminator("poly")
    interface Poly

    @Serializable
    @JsonClassDiscriminator("poly")
    class Impl(val a: Int, val b: String) : Poly

    // Fixture value that is encoded once and decoded by both benchmarks.
    private val impl: Impl = Impl(239, "average_size_string")

    // Registers Impl as the only subclass of Poly.
    private val module: SerializersModule = SerializersModule {
        polymorphic(Poly::class) {
            subclass(Impl.serializer())
        }
    }

    private val json: Json = Json { serializersModule = module }

    // Pre-encoded inputs: the plain form and the form carrying the class discriminator.
    private val implString: String = json.encodeToString(impl)
    private val polyString: String = json.encodeToString<Poly>(impl)
    private val serializer: KSerializer<Poly> = serializer<Poly>()

    // Recorded score: 5000
    // (presumably ops/ms, given the benchmark mode and output unit above -- confirm)
    @Benchmark
    fun base(): Impl {
        return json.decodeFromString(Impl.serializer(), implString)
    }

    // Recorded scores, as of 1.3.x:
    //   Baseline      -- 1500
    //   v1, no skip   -- 2000
    //   v2, with skip -- 3000 [withdrawn]
    @Benchmark
    fun poly(): Poly {
        return json.decodeFromString(serializer, polyString)
    }

}

build.gradle

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ buildscript {
7474

7575
// Various benchmarking stuff
7676
classpath "com.github.jengelman.gradle.plugins:shadow:4.0.2"
77-
classpath "me.champeau.gradle:jmh-gradle-plugin:0.5.3"
78-
classpath "net.ltgt.gradle:gradle-apt-plugin:0.21"
77+
classpath "me.champeau.jmh:jmh-gradle-plugin:0.6.6"
7978
}
8079
}
8180

formats/json/commonMain/src/kotlinx/serialization/json/Json.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ public sealed class Json(
9696
*/
9797
public final override fun <T> decodeFromString(deserializer: DeserializationStrategy<T>, string: String): T {
9898
val lexer = StringJsonLexer(string)
99-
val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor)
99+
val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor, null)
100100
val result = input.decodeSerializableValue(deserializer)
101101
lexer.expectEof()
102102
return result

formats/json/commonMain/src/kotlinx/serialization/json/internal/JsonPath.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ internal class JsonPath {
2424

2525
// Tombstone indicates that we are within a map, but the map key is currently being decoded.
2626
// It is also used to overwrite a previous map key to avoid memory leaks and misattribution.
27-
object Tombstone
27+
private object Tombstone
2828

2929
/*
3030
* Serial descriptor, map key or the tombstone for map key

formats/json/commonMain/src/kotlinx/serialization/json/internal/Polymorphic.kt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import kotlinx.serialization.*
99
import kotlinx.serialization.descriptors.*
1010
import kotlinx.serialization.internal.*
1111
import kotlinx.serialization.json.*
12+
import kotlin.jvm.*
1213

1314
@Suppress("UNCHECKED_CAST")
1415
internal inline fun <T> JsonEncoder.encodePolymorphically(
@@ -55,12 +56,13 @@ internal fun checkKind(kind: SerialKind) {
5556
}
5657

5758
internal fun <T> JsonDecoder.decodeSerializableValuePolymorphic(deserializer: DeserializationStrategy<T>): T {
59+
// NB: changes in this method should be reflected in StreamingJsonDecoder#decodeSerializableValue
5860
if (deserializer !is AbstractPolymorphicSerializer<*> || json.configuration.useArrayPolymorphism) {
5961
return deserializer.deserialize(this)
6062
}
63+
val discriminator = deserializer.descriptor.classDiscriminator(json)
6164

6265
val jsonTree = cast<JsonObject>(decodeJsonElement(), deserializer.descriptor)
63-
val discriminator = deserializer.descriptor.classDiscriminator(json)
6466
val type = jsonTree[discriminator]?.jsonPrimitive?.content
6567
val actualSerializer = deserializer.findPolymorphicSerializerOrNull(this, type)
6668
?: throwSerializerNotFound(type, jsonTree)
@@ -69,7 +71,8 @@ internal fun <T> JsonDecoder.decodeSerializableValuePolymorphic(deserializer: De
6971
return json.readPolymorphicJson(discriminator, jsonTree, actualSerializer as DeserializationStrategy<T>)
7072
}
7173

72-
private fun throwSerializerNotFound(type: String?, jsonTree: JsonObject): Nothing {
74+
@JvmName("throwSerializerNotFound")
75+
internal fun throwSerializerNotFound(type: String?, jsonTree: JsonObject): Nothing {
7376
val suffix =
7477
if (type == null) "missing class discriminator ('null')"
7578
else "class discriminator '$type'"

formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonDecoder.kt

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import kotlinx.serialization.descriptors.*
99
import kotlinx.serialization.encoding.*
1010
import kotlinx.serialization.encoding.CompositeDecoder.Companion.DECODE_DONE
1111
import kotlinx.serialization.encoding.CompositeDecoder.Companion.UNKNOWN_NAME
12+
import kotlinx.serialization.internal.*
1213
import kotlinx.serialization.json.*
1314
import kotlinx.serialization.modules.*
1415
import kotlin.jvm.*
@@ -21,11 +22,27 @@ internal open class StreamingJsonDecoder(
2122
final override val json: Json,
2223
private val mode: WriteMode,
2324
@JvmField internal val lexer: AbstractJsonLexer,
24-
descriptor: SerialDescriptor
25+
descriptor: SerialDescriptor,
26+
discriminatorHolder: DiscriminatorHolder?
2527
) : JsonDecoder, AbstractDecoder() {
2628

29+
// A mutable reference to the discriminator that has to be skipped when in optimistic phase
30+
// of polymorphic serialization, see `decodeSerializableValue`
31+
/**
 * Mutable box for the name of a class discriminator key that may still have to be skipped.
 *
 * [discriminatorToSkip] is `null` when there is nothing (left) to skip.
 */
internal class DiscriminatorHolder(@JvmField var discriminatorToSkip: String?)
32+
33+
/**
 * Checks whether [unknownKey] is the discriminator this holder is waiting to skip.
 *
 * On a match the pending discriminator is cleared, so any later occurrence of the
 * same key no longer matches.
 *
 * @param unknownKey the key to compare with the pending discriminator
 * @return `true` if the key matched the pending discriminator; `false` if the
 * holder is `null` or the key did not match
 */
private fun DiscriminatorHolder?.trySkip(unknownKey: String): Boolean {
    val holder = this ?: return false
    val isPendingDiscriminator = holder.discriminatorToSkip == unknownKey
    // One-shot: forget the discriminator once it has been matched.
    if (isPendingDiscriminator) holder.discriminatorToSkip = null
    return isPendingDiscriminator
}
41+
42+
2743
override val serializersModule: SerializersModule = json.serializersModule
2844
private var currentIndex = -1
45+
private var discriminatorHolder: DiscriminatorHolder? = discriminatorHolder
2946
private val configuration = json.configuration
3047

3148
private val elementMarker: JsonElementMarker? = if (configuration.explicitNulls) null else JsonElementMarker(descriptor)
@@ -35,7 +52,40 @@ internal open class StreamingJsonDecoder(
3552
@Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
3653
override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
3754
try {
38-
return decodeSerializableValuePolymorphic(deserializer)
55+
/*
56+
* This is an optimized path over decodeSerializableValuePolymorphic(deserializer):
57+
* dSVP reads the very next JSON tree into memory as a JsonElement and then runs TreeJsonDecoder over it
58+
* in order to deal with an arbitrary order of keys, but with the price of additional memory pressure
59+
* and CPU consumption.
60+
* We would like to provide best possible performance for data produced by kotlinx.serialization
61+
* itself, for that we do the following optimistic optimization:
62+
*
63+
* 0) Remember current position in the string
64+
* 1) Read the very next key of JSON structure
65+
* 2) If it matches* the discriminator key, read the value, remember current position
66+
* 3) Return the value, restore the initial position
67+
* (*) -- if it doesn't match, fallback to dSVP method.
68+
*/
69+
if (deserializer !is AbstractPolymorphicSerializer<*> || json.configuration.useArrayPolymorphism) {
70+
return deserializer.deserialize(this)
71+
}
72+
73+
val discriminator = deserializer.descriptor.classDiscriminator(json)
74+
val type = lexer.consumeLeadingMatchingValue(discriminator, configuration.isLenient)
75+
var actualSerializer: DeserializationStrategy<out Any>? = null
76+
if (type != null) {
77+
actualSerializer = deserializer.findPolymorphicSerializerOrNull(this, type)
78+
}
79+
if (actualSerializer == null) {
80+
// Fallback if we haven't found discriminator or serializer
81+
return decodeSerializableValuePolymorphic<T>(deserializer as DeserializationStrategy<T>)
82+
}
83+
84+
discriminatorHolder = DiscriminatorHolder(discriminator)
85+
@Suppress("UNCHECKED_CAST")
86+
val result = actualSerializer.deserialize(this) as T
87+
return result
88+
3989
} catch (e: MissingFieldException) {
4090
throw MissingFieldException(e.message + " at path: " + lexer.path.getPath(), e)
4191
}
@@ -52,12 +102,13 @@ internal open class StreamingJsonDecoder(
52102
json,
53103
newMode,
54104
lexer,
55-
descriptor
105+
descriptor,
106+
discriminatorHolder
56107
)
57108
else -> if (mode == newMode && json.configuration.explicitNulls) {
58109
this
59110
} else {
60-
StreamingJsonDecoder(json, newMode, lexer, descriptor)
111+
StreamingJsonDecoder(json, newMode, lexer, descriptor, discriminatorHolder)
61112
}
62113
}
63114
}
@@ -193,7 +244,7 @@ internal open class StreamingJsonDecoder(
193244
}
194245

195246
private fun handleUnknown(key: String): Boolean {
196-
if (configuration.ignoreUnknownKeys) {
247+
if (configuration.ignoreUnknownKeys || discriminatorHolder.trySkip(key)) {
197248
lexer.skipElement(configuration.isLenient)
198249
} else {
199250
// Here we cannot properly update json path indices

formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/AbstractJsonLexer.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,8 @@ internal abstract class AbstractJsonLexer {
283283
return current
284284
}
285285

286+
abstract fun consumeLeadingMatchingValue(keyToMatch: String, isLenient: Boolean): String?
287+
286288
fun peekString(isLenient: Boolean): String? {
287289
val token = peekNextToken()
288290
val string = if (isLenient) {

formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/StringJsonLexer.kt

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ internal class StringJsonLexer(override val source: String) : AbstractJsonLexer(
7878

7979
override fun consumeKeyString(): String {
8080
/*
81-
* For strings we assume that escaped symbols are rather an exception, so firstly
82-
* we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf',
83-
* then do our pessimistic check for backslash and fallback to slow-path if necessary.
84-
*/
81+
* For strings we assume that escaped symbols are rather an exception, so firstly
82+
* we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf',
83+
* then do our pessimistic check for backslash and fallback to slow-path if necessary.
84+
*/
8585
consumeNextToken(STRING)
8686
val current = currentPosition
8787
val closingQuote = source.indexOf('"', current)
@@ -96,4 +96,22 @@ internal class StringJsonLexer(override val source: String) : AbstractJsonLexer(
9696
this.currentPosition = closingQuote + 1
9797
return source.substring(current, closingQuote)
9898
}
99+
100+
/**
 * Optimistically reads the value of the very first key of the upcoming JSON object,
 * provided that this key equals [keyToMatch].
 *
 * The lexer position is restored before this method returns (or throws), so the
 * caller observes the input as if nothing had been consumed.
 *
 * @param keyToMatch the key expected in the first position of the object
 * @param isLenient whether the lenient string reader is used for the key and its
 * value; otherwise the regular key/string readers are used
 * @return the value of the first key as a string, or `null` if the next token does
 * not begin an object, the first key differs from [keyToMatch], or the key is not
 * followed by a colon
 */
override fun consumeLeadingMatchingValue(keyToMatch: String, isLenient: Boolean): String? {
    val positionSnapshot = currentPosition
    try {
        // Malformed JSON, bailout
        if (consumeNextToken() != TC_BEGIN_OBJ) return null
        // The lenient reader must be selected by the lenient flag; the branches used to be
        // swapped, so lenient mode went through the regular readers and strict mode through
        // the lenient one.
        val firstKey = if (isLenient) consumeStringLenientNotNull() else consumeKeyString()
        if (firstKey != keyToMatch) return null
        // Invalid discriminator entry, bailout
        if (consumeNextToken() != TC_COLON) return null
        return if (isLenient) consumeStringLenientNotNull() else consumeString()
    } finally {
        // Restore the position
        currentPosition = positionSnapshot
    }
}
99117
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright 2017-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
package kotlinx.serialization.features
5+
6+
import kotlinx.serialization.*
7+
import kotlinx.serialization.json.*
8+
import kotlinx.serialization.modules.*
9+
import kotlin.test.*
10+
11+
/**
 * Verifies that a default polymorphic deserializer is applied when the class
 * discriminator value does not correspond to a registered subclass.
 */
class DefaultPolymorphicSerializerTest : JsonTestBase() {

    @Serializable
    abstract class Project {
        abstract val name: String
    }

    @Serializable
    data class DefaultProject(
        override val name: String,
        val type: String
    ) : Project()

    // No subclasses are registered: every Project is decoded by the default deserializer.
    val module = SerializersModule {
        polymorphic(Project::class) {
            defaultDeserializer { DefaultProject.serializer() }
        }
    }

    private val json = Json { serializersModule = module }

    @Test
    fun test() = parametrizedTest { jsonTestingMode ->
        val input = """ {"type":"unknown","name":"example"}"""
        val expected = DefaultProject("example", "unknown")
        val decoded = json.decodeFromString<Project>(input, jsonTestingMode)
        assertEquals(expected, decoded)
    }

}

0 commit comments

Comments
 (0)