Skip to content

Commit d3867c2

Browse files
SlobevgСлободской Евгений Геннадьевич
andauthored
#83 | передача functions_state_id через metadata из choice ответа модели в последующие запросы (#84)
Co-authored-by: Слободской Евгений Геннадьевич <Slobodskoy.E.Ge@sberbank.ru>
1 parent f1217b1 commit d3867c2

File tree

2 files changed

+37
-11
lines changed

2 files changed

+37
-11
lines changed

spring-ai-gigachat/src/main/java/chat/giga/springai/GigaChatModel.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,9 @@ private CompletionRequest createRequest(Prompt prompt, boolean stream) {
328328
.functionCall(new CompletionResponse.FunctionCall()
329329
.setName(toolCall.name())
330330
.setArguments(toolCall.arguments()));
331+
} else if (message.getMetadata().containsKey("functions_state_id")) {
332+
messageBuilder.functionsStateId(
333+
(String) message.getMetadata().get("functions_state_id"));
331334
}
332335
return List.of(messageBuilder.build());
333336
} else if (message instanceof ToolResponseMessage toolResponseMessage) {
@@ -476,18 +479,17 @@ private ChatResponse toChatResponse(CompletionResponse completionResponse, Usage
476479
private Generation buildGeneration(String id, CompletionResponse.Choice choice, boolean streaming) {
477480
CompletionResponse.MessagesRes message = streaming ? choice.getDelta() : choice.getMessage();
478481
String finishReason = choice.getFinishReason() != null ? choice.getFinishReason() : "";
479-
Map<String, Object> metadata = Map.of(
480-
"id",
481-
id,
482-
"index",
483-
choice.getIndex(),
484-
"role",
485-
message.getRole() != null ? message.getRole().name() : "",
486-
"finishReason",
487-
finishReason);
482+
String functionsStateId = message.getFunctionsStateId();
483+
Map<String, Object> metadata = new HashMap<>();
484+
metadata.put("id", id);
485+
metadata.put("index", choice.getIndex());
486+
metadata.put("role", message.getRole() != null ? message.getRole().name() : "");
487+
metadata.put("finishReason", finishReason);
488+
if (functionsStateId != null) {
489+
metadata.put("functions_state_id", functionsStateId);
490+
}
488491
List<AssistantMessage.ToolCall> toolCalls;
489492
if (CompletionResponse.FinishReason.FUNCTION_CALL.equals(finishReason)) {
490-
String functionsStateId = message.getFunctionsStateId();
491493
AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall(
492494
functionsStateId,
493495
"function",
@@ -500,7 +502,7 @@ private Generation buildGeneration(String id, CompletionResponse.Choice choice,
500502
var assistantMessage = AssistantMessage.builder()
501503
.content(message.getContent())
502504
.toolCalls(toolCalls)
503-
.properties(metadata)
505+
.properties(Collections.unmodifiableMap(metadata))
504506
.build();
505507
var generationMetadata = ChatGenerationMetadata.builder()
506508
.finishReason(choice.getFinishReason())

spring-ai-gigachat/src/test/java/chat/giga/springai/GigaChatModelTest.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,30 @@ void testGigaChatOptions_withFunctionCallParam() {
100100
assertEquals(functionCallParam, requestFunctionCallParam);
101101
}
102102

103+
@Test
104+
void testGigaChatOptions_withFunctionsStateIdAndFinishReasonStop() {
105+
var functionCallback = GigaTools.from(new TestTool());
106+
107+
var prompt = new Prompt(
108+
List.of(new UserMessage("Hello")),
109+
GigaChatOptions.builder()
110+
.model(GigaChatApi.ChatModel.GIGA_CHAT)
111+
.toolCallbacks(List.of(functionCallback))
112+
.build());
113+
114+
when(gigaChatApi.chatCompletionEntity(any(), any()))
115+
.thenReturn(new ResponseEntity<>(response, HttpStatusCode.valueOf(200)));
116+
117+
when(response.getChoices())
118+
.thenReturn(List.of(new CompletionResponse.Choice()
119+
.setMessage(new CompletionResponse.MessagesRes().setFunctionsStateId("uuid"))));
120+
121+
ChatResponse chatResponse = gigaChatModel.internalCall(prompt, null);
122+
123+
assertTrue(chatResponse.getResult().getOutput().getMetadata().containsKey("functions_state_id"));
124+
assertEquals("uuid", chatResponse.getResult().getOutput().getMetadata().get("functions_state_id"));
125+
}
126+
103127
@Test
104128
void testGigaChatOptions_withFunctionCallEmptyAndTool() {
105129
var functionCallback = GigaTools.from(new TestTool());

0 commit comments

Comments
 (0)