Skip to content

Commit 4e7eb97

Browse files
committed
Add test and improve deserializer
- Make deserializer lenient to allowed values
1 parent 68902f3 commit 4e7eb97

File tree

3 files changed

+134
-34
lines changed

3 files changed

+134
-34
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/LLMModuleResultDeserializer.java

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultStreaming;
88
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
99
import java.io.IOException;
10+
import javax.annotation.Nonnull;
1011

1112
/**
1213
* A deserializer for {@link LLMModuleResult} that determines the concrete implementation based on
@@ -21,58 +22,61 @@ public LLMModuleResultDeserializer() {
2122
/**
2223
* Deserialize the JSON object into one of the subtypes of the base type.
2324
*
25+
* <ul>
26+
* <li>If elements of "choices" array contains "delta", deserialize into {@link
27+
* LLMModuleResultStreaming}.
28+
* <li>Otherwise, deserialize into {@link LLMModuleResultSynchronous}.
29+
* </ul>
30+
*
2431
* @param parser The JSON parser.
2532
* @param context The deserialization context.
2633
* @return The deserialized object.
2734
* @throws IOException If an I/O error occurs.
2835
*/
2936
@Override
30-
public LLMModuleResult deserialize(JsonParser parser, DeserializationContext context)
37+
public LLMModuleResult deserialize(JsonParser parser, @Nonnull DeserializationContext context)
3138
throws IOException {
3239

3340
// Check if the target type is a concrete class
3441
JavaType targetType = context.getContextualType();
35-
3642
if (targetType != null && !LLMModuleResult.class.equals(targetType.getRawClass())) {
37-
// If we're deserializing a concrete class, delegate to the default deserializer
38-
JsonDeserializer<Object> defaultDeserializer = context.findRootValueDeserializer(targetType);
39-
return (LLMModuleResult) defaultDeserializer.deserialize(parser, context);
43+
return delegateToDefaultDeserializer(parser, context, targetType);
4044
}
4145

4246
// Custom deserialization logic for LLMModuleResult interface
43-
ObjectMapper mapper = (ObjectMapper) parser.getCodec();
44-
JsonNode rootNode = mapper.readTree(parser);
47+
var mapper = (ObjectMapper) parser.getCodec();
48+
var rootNode = mapper.readTree(parser);
49+
Class<? extends LLMModuleResult> concreteClass = LLMModuleResultSynchronous.class;
4550

4651
// Inspect the "choices" field
47-
JsonNode choicesNode = rootNode.get("choices");
48-
52+
var choicesNode = rootNode.get("choices");
4953
if (choicesNode != null && choicesNode.isArray()) {
50-
JsonNode firstChoice = choicesNode.get(0);
51-
if (firstChoice != null) {
52-
Class<? extends LLMModuleResult> concreteClass = null;
53-
if (firstChoice.has("delta")) {
54-
concreteClass = LLMModuleResultStreaming.class;
55-
} else if (firstChoice.has("message")) {
56-
concreteClass = LLMModuleResultSynchronous.class;
57-
}
58-
59-
if (concreteClass != null) {
60-
// Deserialize into the determined concrete class
61-
// Create a new parser for the root node
62-
JsonParser rootParser = rootNode.traverse(mapper);
63-
rootParser.nextToken(); // Advance to the first token
64-
65-
// Use the default deserializer for the concrete class
66-
JsonDeserializer<?> deserializer =
67-
context.findRootValueDeserializer(context.constructType(concreteClass));
68-
69-
return (LLMModuleResult) deserializer.deserialize(rootParser, context);
70-
}
54+
var firstChoice = (JsonNode) choicesNode.get(0);
55+
if (firstChoice != null && firstChoice.has("delta")) {
56+
concreteClass = LLMModuleResultStreaming.class;
7157
}
7258
}
7359

74-
// If unable to determine, throw an exception or handle default case
75-
throw new JsonMappingException(
76-
parser, "Unable to determine the concrete implementation of LLMModuleResult");
60+
// Create a new parser for the root node
61+
var rootParser = rootNode.traverse(mapper);
62+
rootParser.nextToken(); // Advance to the first token
63+
64+
// Use the default deserializer for the concrete class
65+
return delegateToDefaultDeserializer(rootParser, context, mapper.constructType(concreteClass));
66+
}
67+
68+
/**
69+
* Delegate deserialization to the default deserializer for the given concrete type.
70+
*
71+
* @param parser The JSON parser.
72+
* @param context The deserialization context.
73+
* @param concreteType The concrete type to deserialize into.
74+
* @return The deserialized object.
75+
* @throws IOException If an I/O error occurs.
76+
*/
77+
private LLMModuleResult delegateToDefaultDeserializer(
78+
JsonParser parser, DeserializationContext context, JavaType concreteType) throws IOException {
79+
var defaultDeserializer = context.findRootValueDeserializer(concreteType);
80+
return (LLMModuleResult) defaultDeserializer.deserialize(parser, context);
7781
}
7882
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/client/model/DPIConfig.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@
3232
value = "org.openapitools.codegen.languages.JavaClientCodegen",
3333
comments = "Generator version: 7.9.0")
3434
public class DPIConfig implements MaskingProviderConfig {
35-
public static final String JSON_PROPERTY_TYPE = "type";
3635
private MethodEnum method;
36+
37+
public static final String JSON_PROPERTY_TYPE = "type";
3738
private TypeEnum type;
3839

3940
/** Create a builder with no initialized field. */
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
4+
5+
import com.fasterxml.jackson.databind.ObjectMapper;
6+
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResult;
7+
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultStreaming;
8+
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
9+
import lombok.SneakyThrows;
10+
import org.junit.jupiter.api.Test;
11+
12+
class LLMModuleResultDeserializerTest {
13+
14+
private final ObjectMapper objectMapper = OrchestrationClient.JACKSON;
15+
private final String jsonTemplate =
16+
"""
17+
{
18+
"object": "chat.completion",
19+
"id": "chatcmpl-12345",
20+
"created": 1234567890,
21+
"model": "gpt-3.5-turbo",
22+
"choices": %s,
23+
"usage": {
24+
"completion_tokens": 10,
25+
"prompt_tokens": 20,
26+
"total_tokens": 30
27+
}
28+
}
29+
""";
30+
31+
@SneakyThrows
32+
@Test
33+
void testSubtypeResolutionSynchronous() {
34+
35+
var choices =
36+
"""
37+
[
38+
{
39+
"index": 0,
40+
"message": {
41+
"role": "assistant",
42+
"content": "Sample response content."
43+
},
44+
"finish_reason": "length"
45+
}
46+
]
47+
""";
48+
49+
var json = String.format(jsonTemplate, choices);
50+
51+
// Deserialize JSON content
52+
LLMModuleResult result = objectMapper.readValue(json, LLMModuleResult.class);
53+
54+
// Assert
55+
assertThat(result).isExactlyInstanceOf(LLMModuleResultSynchronous.class);
56+
}
57+
58+
@SneakyThrows
59+
@Test
60+
void testSubtypeResolutionStreaming() {
61+
var choices =
62+
"""
63+
[
64+
{
65+
"index": 0,
66+
"delta": {
67+
"content": "Sample response content."
68+
}
69+
}
70+
]
71+
""";
72+
73+
var json = String.format(jsonTemplate, choices);
74+
75+
// Deserialize JSON content
76+
LLMModuleResult result = objectMapper.readValue(json, LLMModuleResult.class);
77+
78+
// Assert
79+
assertThat(result).isExactlyInstanceOf(LLMModuleResultStreaming.class);
80+
}
81+
82+
@SneakyThrows
83+
@Test
84+
void testSubtypeResolutionEmptyChoices() {
85+
86+
String choice = "[]";
87+
var json = String.format(jsonTemplate, choice);
88+
89+
// Deserialize JSON content
90+
LLMModuleResult result = objectMapper.readValue(json, LLMModuleResult.class);
91+
92+
// Assert
93+
assertThat(result).isExactlyInstanceOf(LLMModuleResultSynchronous.class);
94+
}
95+
}

0 commit comments

Comments
 (0)