Skip to content

Commit d5ecdc9

Browse files
Jonas-Isra-d
andauthored
feat: [Orchestration] Enable local prompt templates (#423)
* WIP (first examples work) * Add tests for simple case. * Small fixes * Add simple e2e test * WiP (stuck) * WiP (stuck) * WiP * after meeting with Alex * Fix test * all tests green * small stuff * small stuff * small stuff * fix sample app * fix dependency issue * fix link to docs * use dependencyManagement * small fix * remove unnecessary resourceLoader --------- Co-authored-by: Jonas Israel <[email protected]> Co-authored-by: Alexander Dümont <[email protected]>
1 parent 657ce08 commit d5ecdc9

File tree

15 files changed

+465
-7
lines changed

15 files changed

+465
-7
lines changed

docs/release_notes.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
### ✨ New Functionality
1414

15-
-
15+
- [Orchestration] [Added support to locally test prompt template files](https://sap.github.io/ai-sdk/docs/java/orchestration/chat-completion#locally-test-a-prompt-template)
1616

1717
### 📈 Improvements
1818

orchestration/pom.xml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
<coverage.complexity>83%</coverage.complexity>
3535
<coverage.line>94%</coverage.line>
3636
<coverage.instruction>94%</coverage.instruction>
37-
<coverage.branch>77%</coverage.branch>
37+
<coverage.branch>78%</coverage.branch>
3838
<coverage.method>93%</coverage.method>
3939
<coverage.class>100%</coverage.class>
4040
</properties>
@@ -108,6 +108,10 @@
108108
<groupId>com.github.victools</groupId>
109109
<artifactId>jsonschema-module-jackson</artifactId>
110110
</dependency>
111+
<dependency>
112+
<groupId>com.fasterxml.jackson.dataformat</groupId>
113+
<artifactId>jackson-dataformat-yaml</artifactId>
114+
</dependency>
111115
<!-- scope "provided" -->
112116
<dependency>
113117
<groupId>org.projectlombok</groupId>

orchestration/src/main/java/com/sap/ai/sdk/orchestration/JacksonMixins.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.sap.ai.sdk.orchestration;
22

3+
import com.fasterxml.jackson.annotation.JsonSubTypes;
34
import com.fasterxml.jackson.annotation.JsonTypeInfo;
45
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
56
import com.sap.ai.sdk.orchestration.model.LLMChoice;
@@ -21,4 +22,22 @@ interface ModuleResultsOutputUnmaskingInnerMixIn {}
2122

2223
@JsonTypeInfo(use = JsonTypeInfo.Id.NONE)
2324
interface NoneTypeInfoMixin {}
25+
26+
@JsonTypeInfo(
27+
use = JsonTypeInfo.Id.NAME,
28+
include = JsonTypeInfo.As.PROPERTY,
29+
property = "type",
30+
visible = true)
31+
@JsonSubTypes({
32+
@JsonSubTypes.Type(
33+
value = com.sap.ai.sdk.orchestration.model.ResponseFormatJsonSchema.class,
34+
name = "json_schema"),
35+
@JsonSubTypes.Type(
36+
value = com.sap.ai.sdk.orchestration.model.ResponseFormatJsonObject.class,
37+
name = "json_object"),
38+
@JsonSubTypes.Type(
39+
value = com.sap.ai.sdk.orchestration.model.ResponseFormatText.class,
40+
name = "text")
41+
})
42+
interface ResponseFormatSubTypesMixin {}
2443
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationJacksonConfiguration.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import com.sap.ai.sdk.orchestration.model.ChatMessage;
99
import com.sap.ai.sdk.orchestration.model.LLMModuleResult;
1010
import com.sap.ai.sdk.orchestration.model.ModuleResultsOutputUnmaskingInner;
11+
import com.sap.ai.sdk.orchestration.model.TemplateResponseFormat;
1112
import javax.annotation.Nonnull;
1213
import lombok.AccessLevel;
1314
import lombok.NoArgsConstructor;
@@ -47,7 +48,12 @@ public static ObjectMapper getOrchestrationObjectMapper() {
4748
.addDeserializer(
4849
ChatMessage.class,
4950
PolymorphicFallbackDeserializer.fromJsonSubTypes(ChatMessage.class))
50-
.setMixInAnnotation(ChatMessage.class, JacksonMixins.NoneTypeInfoMixin.class);
51+
.addDeserializer(
52+
TemplateResponseFormat.class,
53+
PolymorphicFallbackDeserializer.fromJsonSubTypes(TemplateResponseFormat.class))
54+
.setMixInAnnotation(ChatMessage.class, JacksonMixins.NoneTypeInfoMixin.class)
55+
.setMixInAnnotation(
56+
TemplateResponseFormat.class, JacksonMixins.ResponseFormatSubTypesMixin.class);
5157
jackson.registerModule(module);
5258
return jackson;
5359
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationTemplate.java

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
package com.sap.ai.sdk.orchestration;
22

3+
import com.fasterxml.jackson.annotation.JsonProperty;
4+
import com.fasterxml.jackson.core.JsonProcessingException;
5+
import com.fasterxml.jackson.databind.JsonNode;
6+
import com.fasterxml.jackson.databind.ObjectMapper;
7+
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
38
import com.google.common.annotations.Beta;
49
import com.sap.ai.sdk.orchestration.model.ChatCompletionTool;
510
import com.sap.ai.sdk.orchestration.model.ChatMessage;
@@ -9,6 +14,7 @@
914
import com.sap.ai.sdk.orchestration.model.Template;
1015
import com.sap.ai.sdk.orchestration.model.TemplateResponseFormat;
1116
import com.sap.ai.sdk.orchestration.model.TemplatingModuleConfig;
17+
import java.io.IOException;
1218
import java.util.ArrayList;
1319
import java.util.HashMap;
1420
import java.util.List;
@@ -35,14 +41,22 @@
3541
@NoArgsConstructor(force = true, access = AccessLevel.PACKAGE)
3642
@Beta
3743
public class OrchestrationTemplate extends TemplateConfig {
38-
@Nullable List<ChatMessage> template;
39-
@Nullable Map<String, String> defaults;
44+
@JsonProperty("template")
45+
@Nullable
46+
List<ChatMessage> template;
47+
48+
@JsonProperty("defaults")
49+
@Nullable
50+
Map<String, String> defaults;
4051

52+
@JsonProperty("response_format")
4153
@With(AccessLevel.PRIVATE)
4254
@Nullable
4355
TemplateResponseFormat responseFormat;
4456

45-
@Nullable List<ChatCompletionTool> tools;
57+
@JsonProperty("tools")
58+
@Nullable
59+
List<ChatCompletionTool> tools;
4660

4761
/**
4862
* Create a low-level representation of the template.
@@ -93,4 +107,45 @@ public OrchestrationTemplate withJsonResponse() {
93107
ResponseFormatJsonObject.create().type(ResponseFormatJsonObject.TypeEnum.JSON_OBJECT);
94108
return this.withResponseFormat(responseFormatJsonObject);
95109
}
110+
111+
/**
112+
* Create a {@link Template} object from a JSON provided as String.
113+
*
114+
* @throws IOException if the JSON cannot be deserialized
115+
* @param inputString the provided JSON
116+
* @return A Template object representing the provided JSON
117+
* @since 1.7.0
118+
*/
119+
@Nullable
120+
private OrchestrationTemplate fromJson(@Nonnull final String inputString) throws IOException {
121+
final ObjectMapper objectMapper =
122+
OrchestrationJacksonConfiguration.getOrchestrationObjectMapper();
123+
final JsonNode rootNode = objectMapper.readTree(inputString);
124+
return objectMapper.treeToValue(rootNode.get("spec"), OrchestrationTemplate.class);
125+
}
126+
127+
/**
128+
* Create a {@link Template} object from a YAML provided as String.
129+
*
130+
* @throws IOException if the YAML cannot be parsed or deserialized
131+
* @param inputYaml the provided YAML
132+
* @return A Template object representing the provided YAML
133+
* @since 1.7.0
134+
*/
135+
@Nullable
136+
public OrchestrationTemplate fromYaml(@Nonnull final String inputYaml) throws IOException {
137+
final Object obj;
138+
try {
139+
final ObjectMapper yamlReader = new ObjectMapper(new YAMLFactory());
140+
obj = yamlReader.readValue(inputYaml, Object.class);
141+
} catch (JsonProcessingException ex) {
142+
throw new IOException("Failed to parse the YAML input: " + ex.getMessage(), ex);
143+
}
144+
try {
145+
final ObjectMapper jsonWriter = new ObjectMapper();
146+
return fromJson(jsonWriter.writeValueAsString(obj));
147+
} catch (JsonProcessingException ex) {
148+
throw new IOException("Failed to deserialize the input: " + ex.getMessage(), ex);
149+
}
150+
}
96151
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationConvenienceUnitTest.java

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
import com.sap.ai.sdk.orchestration.model.TemplateRef;
1515
import com.sap.ai.sdk.orchestration.model.TemplateRefByID;
1616
import com.sap.ai.sdk.orchestration.model.TemplateRefByScenarioNameVersion;
17+
import java.io.IOException;
18+
import java.nio.file.Files;
19+
import java.nio.file.Path;
1720
import java.util.LinkedHashMap;
1821
import java.util.List;
1922
import java.util.Map;
@@ -177,4 +180,96 @@ void testTemplateReferenceConstruction() {
177180
assertThat(templateReferenceScenarioNameVersion.toLowLevel())
178181
.isEqualTo(templateReferenceScenarioNameVersionLowLevel);
179182
}
183+
184+
@Test
185+
void testTemplateFromLocalFileWithJsonSchemaAndTools() throws IOException {
186+
String promptTemplateYaml =
187+
Files.readString(Path.of("src/test/resources/promptTemplateExample.yaml"));
188+
var templateWithJsonSchemaTools = TemplateConfig.create().fromYaml(promptTemplateYaml);
189+
var schema =
190+
Map.of(
191+
"type",
192+
"object",
193+
"properties",
194+
Map.of(
195+
"language", Map.of("type", "string"),
196+
"translation", Map.of("type", "string")),
197+
"required",
198+
List.of("language", "translation"),
199+
"additionalProperties",
200+
false);
201+
var expectedTemplateWithJsonSchemaTools =
202+
OrchestrationTemplate.create()
203+
.withTemplate(
204+
List.of(
205+
SingleChatMessage.create()
206+
.role("system")
207+
.content("You are a language translator."),
208+
SingleChatMessage.create()
209+
.role("user")
210+
.content("Whats {{ ?word }} in {{ ?language }}?")))
211+
.withDefaults(Map.of("word", "apple"))
212+
.withJsonSchemaResponse(
213+
ResponseJsonSchema.fromMap(schema, "translation-schema")
214+
.withDescription("Translate the given word into the provided language.")
215+
.withStrict(true))
216+
.withTools(
217+
List.of(
218+
ChatCompletionTool.create()
219+
.type(ChatCompletionTool.TypeEnum.FUNCTION)
220+
.function(
221+
FunctionObject.create()
222+
.name("translate")
223+
.parameters(
224+
Map.of(
225+
"type",
226+
"object",
227+
"additionalProperties",
228+
false,
229+
"required",
230+
List.of("language", "wordToTranslate"),
231+
"properties",
232+
Map.of(
233+
"language", Map.of("type", "string"),
234+
"wordToTranslate", Map.of("type", "string"))))
235+
.description("Translate a word.")
236+
.strict(true))));
237+
assertThat(templateWithJsonSchemaTools).isEqualTo(expectedTemplateWithJsonSchemaTools);
238+
}
239+
240+
@Test
241+
void testTemplateFromLocalFileWithJsonObject() throws IOException {
242+
String promptTemplateWithJsonObject =
243+
"""
244+
name: translator
245+
version: 0.0.1
246+
scenario: translation scenario
247+
spec:
248+
template:
249+
- role: "system"
250+
content: |-
251+
You are a language translator.
252+
- role: "user"
253+
content: |-
254+
Whats {{ ?word }} in {{ ?language }}?
255+
defaults:
256+
word: "apple"
257+
response_format:
258+
type: json_object
259+
""";
260+
var templateWithJsonObject = TemplateConfig.create().fromYaml(promptTemplateWithJsonObject);
261+
var expectedTemplateWithJsonObject =
262+
OrchestrationTemplate.create()
263+
.withTemplate(
264+
List.of(
265+
SingleChatMessage.create()
266+
.role("system")
267+
.content("You are a language translator."),
268+
SingleChatMessage.create()
269+
.role("user")
270+
.content("Whats {{ ?word }} in {{ ?language }}?")))
271+
.withDefaults(Map.of("word", "apple"))
272+
.withJsonResponse();
273+
assertThat(templateWithJsonObject).isEqualTo(expectedTemplateWithJsonObject);
274+
}
180275
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
5656
import java.io.IOException;
5757
import java.io.InputStream;
58+
import java.nio.file.Files;
59+
import java.nio.file.Path;
5860
import java.util.List;
5961
import java.util.Map;
6062
import java.util.Objects;
@@ -978,4 +980,46 @@ void testTemplateFromPromptRegistryByScenario() throws IOException {
978980
verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request)));
979981
}
980982
}
983+
984+
@Test
985+
void testTemplateFromInput() throws IOException {
986+
stubFor(
987+
post(anyUrl())
988+
.willReturn(
989+
aResponse()
990+
.withBodyFile("templateReferenceResponse.json")
991+
.withHeader("Content-Type", "application/json")));
992+
993+
var promptTemplateYaml =
994+
Files.readString(Path.of("src/test/resources/promptTemplateExample.yaml"));
995+
996+
var template = TemplateConfig.create().fromYaml(promptTemplateYaml);
997+
var configWithTemplate = template != null ? config.withTemplateConfig(template) : config;
998+
999+
var inputParams = Map.of("language", "German");
1000+
var prompt = new OrchestrationPrompt(inputParams);
1001+
1002+
final var response = client.chatCompletion(prompt, configWithTemplate);
1003+
1004+
try (var requestInputStream = fileLoader.apply("localTemplateRequest.json")) {
1005+
final String request = new String(requestInputStream.readAllBytes());
1006+
verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request)));
1007+
}
1008+
}
1009+
1010+
@Test
1011+
void testTemplateFromInputThrows() {
1012+
assertThatThrownBy(() -> TemplateConfig.create().fromYaml(": what?"))
1013+
.isInstanceOf(IOException.class)
1014+
.hasMessageContaining("Failed to parse");
1015+
1016+
prompt = new OrchestrationPrompt(Map.of());
1017+
assertThatThrownBy(
1018+
() ->
1019+
TemplateConfig.create()
1020+
.fromYaml(
1021+
"name: translator\nversion: 0.0.1\nscenario: translation scenario\nspec:\n template: what?"))
1022+
.isInstanceOf(IOException.class)
1023+
.hasMessageContaining("Failed to deserialize");
1024+
}
9811025
}

0 commit comments

Comments
 (0)