Skip to content

Commit 652574e

Browse files
feat(client): make union deserialization more robust (#385)
feat(client): add enum validation method chore(client): remove unnecessary json state from some query param classes chore(internal): add json roundtripping tests chore(internal): add invalid json deserialization tests
1 parent 5fae8b7 commit 652574e

File tree

550 files changed

+33689
-2262
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

550 files changed

+33689
-2262
lines changed

openai-java-core/src/main/kotlin/com/openai/core/BaseDeserializer.kt

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@ import com.fasterxml.jackson.databind.BeanProperty
77
import com.fasterxml.jackson.databind.DeserializationContext
88
import com.fasterxml.jackson.databind.JavaType
99
import com.fasterxml.jackson.databind.JsonDeserializer
10-
import com.fasterxml.jackson.databind.JsonMappingException
1110
import com.fasterxml.jackson.databind.JsonNode
1211
import com.fasterxml.jackson.databind.deser.ContextualDeserializer
1312
import com.fasterxml.jackson.databind.deser.std.StdDeserializer
14-
import com.openai.errors.OpenAIInvalidDataException
1513
import kotlin.reflect.KClass
1614

1715
abstract class BaseDeserializer<T : Any>(type: KClass<T>) :
@@ -30,38 +28,17 @@ abstract class BaseDeserializer<T : Any>(type: KClass<T>) :
3028

3129
protected abstract fun ObjectCodec.deserialize(node: JsonNode): T
3230

33-
protected fun <T> ObjectCodec.deserialize(node: JsonNode, type: TypeReference<T>): T =
31+
protected fun <T> ObjectCodec.tryDeserialize(node: JsonNode, type: TypeReference<T>): T? =
3432
try {
3533
readValue(treeAsTokens(node), type)
3634
} catch (e: Exception) {
37-
throw OpenAIInvalidDataException("Error deserializing", e)
38-
}
39-
40-
protected fun <T> ObjectCodec.tryDeserialize(
41-
node: JsonNode,
42-
type: TypeReference<T>,
43-
validate: (T) -> Unit = {},
44-
): T? {
45-
return try {
46-
readValue(treeAsTokens(node), type).apply(validate)
47-
} catch (e: JsonMappingException) {
48-
null
49-
} catch (e: RuntimeException) {
5035
null
5136
}
52-
}
5337

54-
protected fun <T> ObjectCodec.tryDeserialize(
55-
node: JsonNode,
56-
type: JavaType,
57-
validate: (T) -> Unit = {},
58-
): T? {
59-
return try {
60-
readValue<T>(treeAsTokens(node), type).apply(validate)
61-
} catch (e: JsonMappingException) {
62-
null
63-
} catch (e: RuntimeException) {
38+
protected fun <T> ObjectCodec.tryDeserialize(node: JsonNode, type: JavaType): T? =
39+
try {
40+
readValue(treeAsTokens(node), type)
41+
} catch (e: Exception) {
6442
null
6543
}
66-
}
6744
}

openai-java-core/src/main/kotlin/com/openai/core/Utils.kt

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,34 @@ internal fun <K : Comparable<K>, V> SortedMap<K, V>.toImmutable(): SortedMap<K,
2525
if (isEmpty()) Collections.emptySortedMap()
2626
else Collections.unmodifiableSortedMap(toSortedMap(comparator()))
2727

28+
/**
29+
* Returns all elements that yield the largest value for the given function, or an empty list if
30+
* there are zero elements.
31+
*
32+
* This is similar to [Sequence.maxByOrNull] except it returns _all_ elements that yield the largest
33+
* value; not just the first one.
34+
*/
35+
@JvmSynthetic
36+
internal fun <T, R : Comparable<R>> Sequence<T>.allMaxBy(selector: (T) -> R): List<T> {
37+
var maxValue: R? = null
38+
val maxElements = mutableListOf<T>()
39+
40+
val iterator = iterator()
41+
while (iterator.hasNext()) {
42+
val element = iterator.next()
43+
val value = selector(element)
44+
if (maxValue == null || value > maxValue) {
45+
maxValue = value
46+
maxElements.clear()
47+
maxElements.add(element)
48+
} else if (value == maxValue) {
49+
maxElements.add(element)
50+
}
51+
}
52+
53+
return maxElements
54+
}
55+
2856
/**
2957
* Returns whether [this] is equal to [other].
3058
*

openai-java-core/src/main/kotlin/com/openai/models/AllModels.kt

Lines changed: 86 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import com.openai.core.BaseSerializer
1515
import com.openai.core.Enum
1616
import com.openai.core.JsonField
1717
import com.openai.core.JsonValue
18+
import com.openai.core.allMaxBy
1819
import com.openai.core.getOrThrow
1920
import com.openai.errors.OpenAIInvalidDataException
2021
import java.util.Objects
@@ -50,14 +51,13 @@ private constructor(
5051

5152
fun _json(): Optional<JsonValue> = Optional.ofNullable(_json)
5253

53-
fun <T> accept(visitor: Visitor<T>): T {
54-
return when {
54+
fun <T> accept(visitor: Visitor<T>): T =
55+
when {
5556
string != null -> visitor.visitString(string)
5657
chatModel != null -> visitor.visitChatModel(chatModel)
5758
unionMember2 != null -> visitor.visitUnionMember2(unionMember2)
5859
else -> visitor.unknown(_json)
5960
}
60-
}
6161

6262
private var validated: Boolean = false
6363

@@ -70,14 +70,45 @@ private constructor(
7070
object : Visitor<Unit> {
7171
override fun visitString(string: String) {}
7272

73-
override fun visitChatModel(chatModel: ChatModel) {}
73+
override fun visitChatModel(chatModel: ChatModel) {
74+
chatModel.validate()
75+
}
7476

75-
override fun visitUnionMember2(unionMember2: UnionMember2) {}
77+
override fun visitUnionMember2(unionMember2: UnionMember2) {
78+
unionMember2.validate()
79+
}
7680
}
7781
)
7882
validated = true
7983
}
8084

85+
fun isValid(): Boolean =
86+
try {
87+
validate()
88+
true
89+
} catch (e: OpenAIInvalidDataException) {
90+
false
91+
}
92+
93+
/**
94+
* Returns a score indicating how many valid values are contained in this object recursively.
95+
*
96+
* Used for best match union deserialization.
97+
*/
98+
@JvmSynthetic
99+
internal fun validity(): Int =
100+
accept(
101+
object : Visitor<Int> {
102+
override fun visitString(string: String) = 1
103+
104+
override fun visitChatModel(chatModel: ChatModel) = chatModel.validity()
105+
106+
override fun visitUnionMember2(unionMember2: UnionMember2) = unionMember2.validity()
107+
108+
override fun unknown(json: JsonValue?) = 0
109+
}
110+
)
111+
81112
override fun equals(other: Any?): Boolean {
82113
if (this === other) {
83114
return true
@@ -135,17 +166,30 @@ private constructor(
135166
override fun ObjectCodec.deserialize(node: JsonNode): AllModels {
136167
val json = JsonValue.fromJsonNode(node)
137168

138-
tryDeserialize(node, jacksonTypeRef<String>())?.let {
139-
return AllModels(string = it, _json = json)
140-
}
141-
tryDeserialize(node, jacksonTypeRef<ChatModel>())?.let {
142-
return AllModels(chatModel = it, _json = json)
143-
}
144-
tryDeserialize(node, jacksonTypeRef<UnionMember2>())?.let {
145-
return AllModels(unionMember2 = it, _json = json)
169+
val bestMatches =
170+
sequenceOf(
171+
tryDeserialize(node, jacksonTypeRef<ChatModel>())?.let {
172+
AllModels(chatModel = it, _json = json)
173+
},
174+
tryDeserialize(node, jacksonTypeRef<UnionMember2>())?.let {
175+
AllModels(unionMember2 = it, _json = json)
176+
},
177+
tryDeserialize(node, jacksonTypeRef<String>())?.let {
178+
AllModels(string = it, _json = json)
179+
},
180+
)
181+
.filterNotNull()
182+
.allMaxBy { it.validity() }
183+
.toList()
184+
return when (bestMatches.size) {
185+
// This can happen if what we're deserializing is completely incompatible with all
186+
// the possible variants (e.g. deserializing from object).
187+
0 -> AllModels(_json = json)
188+
1 -> bestMatches.single()
189+
// If there's more than one match with the highest validity, then use the first
190+
// completely valid match, or simply the first match if none are completely valid.
191+
else -> bestMatches.firstOrNull { it.isValid() } ?: bestMatches.first()
146192
}
147-
148-
return AllModels(_json = json)
149193
}
150194
}
151195

@@ -266,6 +310,33 @@ private constructor(
266310
fun asString(): String =
267311
_value().asString().orElseThrow { OpenAIInvalidDataException("Value is not a String") }
268312

313+
private var validated: Boolean = false
314+
315+
fun validate(): UnionMember2 = apply {
316+
if (validated) {
317+
return@apply
318+
}
319+
320+
known()
321+
validated = true
322+
}
323+
324+
fun isValid(): Boolean =
325+
try {
326+
validate()
327+
true
328+
} catch (e: OpenAIInvalidDataException) {
329+
false
330+
}
331+
332+
/**
333+
* Returns a score indicating how many valid values are contained in this object
334+
* recursively.
335+
*
336+
* Used for best match union deserialization.
337+
*/
338+
@JvmSynthetic internal fun validity(): Int = if (value() == Value._UNKNOWN) 0 else 1
339+
269340
override fun equals(other: Any?): Boolean {
270341
if (this === other) {
271342
return true

openai-java-core/src/main/kotlin/com/openai/models/ChatModel.kt

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,32 @@ class ChatModel @JsonCreator private constructor(private val value: JsonField<St
338338
fun asString(): String =
339339
_value().asString().orElseThrow { OpenAIInvalidDataException("Value is not a String") }
340340

341+
private var validated: Boolean = false
342+
343+
fun validate(): ChatModel = apply {
344+
if (validated) {
345+
return@apply
346+
}
347+
348+
known()
349+
validated = true
350+
}
351+
352+
fun isValid(): Boolean =
353+
try {
354+
validate()
355+
true
356+
} catch (e: OpenAIInvalidDataException) {
357+
false
358+
}
359+
360+
/**
361+
* Returns a score indicating how many valid values are contained in this object recursively.
362+
*
363+
* Used for best match union deserialization.
364+
*/
365+
@JvmSynthetic internal fun validity(): Int = if (value() == Value._UNKNOWN) 0 else 1
366+
341367
override fun equals(other: Any?): Boolean {
342368
if (this === other) {
343369
return true

0 commit comments

Comments
 (0)