Skip to content

Commit 80924d1

Browse files
newtorka-d
andauthored
fix: [Orchestration] default values for template messages (#391)
* Initial * Add test assertion --------- Co-authored-by: Alexander Dümont <[email protected]>
1 parent f4bb4e9 commit 80924d1

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,30 @@ static TemplatingModuleConfig toTemplateModuleConfig(
5151
if (config instanceof TemplateRef) {
5252
return config;
5353
}
54+
5455
val template = config instanceof Template t ? t : Template.create().template();
5556
val messages = template.getTemplate();
5657
val responseFormat = template.getResponseFormat();
5758
val messagesWithPrompt = new ArrayList<>(messages);
59+
5860
messagesWithPrompt.addAll(
5961
prompt.getMessages().stream().map(Message::createChatMessage).toList());
6062
if (messagesWithPrompt.isEmpty()) {
6163
throw new IllegalStateException(
6264
"A prompt is required. Pass at least one message or configure a template with messages or a template reference.");
6365
}
64-
return Template.create()
65-
.template(messagesWithPrompt)
66-
.tools(template.getTools())
67-
.responseFormat(responseFormat);
66+
67+
val result =
68+
Template.create()
69+
.template(messagesWithPrompt)
70+
.tools(template.getTools())
71+
.defaults(template.getDefaults())
72+
.responseFormat(responseFormat);
73+
74+
for (val customFieldName : template.getCustomFieldNames()) {
75+
result.setCustomField(customFieldName, template.getCustomField(customFieldName));
76+
}
77+
return result;
6878
}
6979

7080
@Nonnull

orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,16 @@ void testMergingTemplateConfig() {
6262
List.of(
6363
systemMessage.createChatMessage(),
6464
userMessage.createChatMessage(),
65-
userMessage2.createChatMessage()));
65+
userMessage2.createChatMessage()))
66+
.defaults(Map.of("city", "Paris"));
67+
expected.setCustomField("country", "France");
6668

6769
var prompt = new OrchestrationPrompt(userMessage2);
6870
var templateConfig =
6971
Template.create()
70-
.template(List.of(systemMessage.createChatMessage(), userMessage.createChatMessage()));
72+
.template(List.of(systemMessage.createChatMessage(), userMessage.createChatMessage()))
73+
.defaults(Map.of("city", "Paris"));
74+
templateConfig.setCustomField("country", "France");
7175
var actual = ConfigToRequestTransformer.toTemplateModuleConfig(prompt, templateConfig);
7276

7377
assertThat(actual).isEqualTo(expected);

0 commit comments

Comments
 (0)