diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java index f5fab8ce1..9fb28feeb 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java @@ -51,20 +51,30 @@ static TemplatingModuleConfig toTemplateModuleConfig( if (config instanceof TemplateRef) { return config; } + val template = config instanceof Template t ? t : Template.create().template(); val messages = template.getTemplate(); val responseFormat = template.getResponseFormat(); val messagesWithPrompt = new ArrayList<>(messages); + messagesWithPrompt.addAll( prompt.getMessages().stream().map(Message::createChatMessage).toList()); if (messagesWithPrompt.isEmpty()) { throw new IllegalStateException( "A prompt is required. Pass at least one message or configure a template with messages or a template reference."); } - return Template.create() - .template(messagesWithPrompt) - .tools(template.getTools()) - .responseFormat(responseFormat); + + val result = + Template.create() + .template(messagesWithPrompt) + .tools(template.getTools()) + .defaults(template.getDefaults()) + .responseFormat(responseFormat); + + for (val customFieldName : template.getCustomFieldNames()) { + result.setCustomField(customFieldName, template.getCustomField(customFieldName)); + } + return result; } @Nonnull diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java index 242d570af..1a35187dc 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java @@ -62,12 +62,16 @@ void testMergingTemplateConfig() { List.of( systemMessage.createChatMessage(), userMessage.createChatMessage(), - userMessage2.createChatMessage())); + userMessage2.createChatMessage())) + .defaults(Map.of("city", "Paris")); + expected.setCustomField("country", "France"); var prompt = new OrchestrationPrompt(userMessage2); var templateConfig = Template.create() - .template(List.of(systemMessage.createChatMessage(), userMessage.createChatMessage())); + .template(List.of(systemMessage.createChatMessage(), userMessage.createChatMessage())) + .defaults(Map.of("city", "Paris")); + templateConfig.setCustomField("country", "France"); var actual = ConfigToRequestTransformer.toTemplateModuleConfig(prompt, templateConfig); assertThat(actual).isEqualTo(expected);