Skip to content

Commit 55cdf06

Browse files
committed
[fel] record LLM results to memory
1 parent bd9bcfa commit 55cdf06

File tree

3 files changed

+107
-0
lines changed

3 files changed

+107
-0
lines changed

framework/fel/java/fel-flow/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,5 +74,10 @@
7474
<artifactId>assertj-core</artifactId>
7575
<scope>test</scope>
7676
</dependency>
77+
<dependency>
78+
<groupId>org.mockito</groupId>
79+
<artifactId>mockito-core</artifactId>
80+
<scope>test</scope>
81+
</dependency>
7782
</dependencies>
7883
</project>

framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/LlmEmitter.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88

99
import modelengine.fel.core.chat.ChatMessage;
1010
import modelengine.fel.core.chat.Prompt;
11+
import modelengine.fel.core.chat.support.HumanMessage;
12+
import modelengine.fel.core.memory.Memory;
1113
import modelengine.fel.engine.util.StateKey;
1214
import modelengine.fit.waterflow.bridge.fitflow.FitBoundedEmitter;
1315
import modelengine.fit.waterflow.domain.context.FlowSession;
1416
import modelengine.fitframework.flowable.Publisher;
1517
import modelengine.fitframework.inspection.Validation;
1618
import modelengine.fitframework.util.ObjectUtils;
19+
import modelengine.fitframework.util.StringUtils;
1720

1821
/**
1922
* 流式模型发射器。
@@ -26,6 +29,8 @@ public class LlmEmitter<O extends ChatMessage> extends FitBoundedEmitter<O, Chat
2629

2730
private final ChatChunk chunkAcc = new ChatChunk();
2831
private final StreamingConsumer<ChatMessage, ChatMessage> consumer;
32+
private final Memory memory;
33+
private final ChatMessage question;
2934

3035
/**
3136
* 初始化 {@link LlmEmitter}。
@@ -38,6 +43,9 @@ public LlmEmitter(Publisher<O> publisher, Prompt prompt, FlowSession session) {
3843
super(publisher, data -> data);
3944
Validation.notNull(session, "The session cannot be null.");
4045
this.consumer = ObjectUtils.nullIf(session.getInnerState(StateKey.STREAMING_CONSUMER), EMPTY_CONSUMER);
46+
this.memory = session.getInnerState(StateKey.HISTORY);
47+
this.question =
48+
ObjectUtils.getIfNull(session.getInnerState(StateKey.HISTORY_INPUT), () -> getDefaultQuestion(prompt));
4149
}
4250

4351
@Override
@@ -46,4 +54,21 @@ public void emit(ChatMessage data, FlowSession trans) {
4654
this.chunkAcc.merge(data);
4755
this.consumer.accept(this.chunkAcc, data);
4856
}
57+
58+
@Override
59+
public void complete() {
60+
if (this.memory != null && this.chunkAcc.toolCalls().isEmpty()) {
61+
this.memory.add(this.question);
62+
this.memory.add(this.chunkAcc);
63+
}
64+
super.complete();
65+
}
66+
67+
private static ChatMessage getDefaultQuestion(Prompt prompt) {
68+
int size = prompt.messages().size();
69+
if (size == 0) {
70+
return new HumanMessage(StringUtils.EMPTY);
71+
}
72+
return prompt.messages().get(size - 1);
73+
}
4974
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*---------------------------------------------------------------------------------------------
2+
* Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3+
* This file is a part of the ModelEngine Project.
4+
* Licensed under the MIT License. See License.txt in the project root for license information.
5+
*--------------------------------------------------------------------------------------------*/
6+
7+
package modelengine.fel.engine.operators.models;
8+
9+
import modelengine.fel.core.chat.ChatMessage;
10+
import modelengine.fel.core.chat.Prompt;
11+
import modelengine.fel.core.chat.support.AiMessage;
12+
import modelengine.fel.core.chat.support.ChatMessages;
13+
import modelengine.fel.core.memory.Memory;
14+
import modelengine.fel.core.tool.ToolCall;
15+
import modelengine.fel.engine.util.StateKey;
16+
import modelengine.fit.waterflow.domain.context.FlowSession;
17+
import modelengine.fitframework.flowable.Choir;
18+
import modelengine.fitframework.util.StringUtils;
19+
20+
import org.junit.jupiter.api.Test;
21+
import org.mockito.ArgumentCaptor;
22+
import org.mockito.Mockito;
23+
24+
import java.util.Arrays;
25+
import java.util.Collections;
26+
import java.util.List;
27+
28+
import static org.junit.jupiter.api.Assertions.*;
29+
30+
/**
31+
* 表示 {@link LlmEmitter} 的测试。
32+
*
33+
* @author 宋永坦
34+
* @since 2025-07-05
35+
*/
36+
class LlmEmitterTest {
37+
@Test
38+
void shouldAddMemoryWhenCompleteGivenLlmOutput() {
39+
String output = "data1";
40+
Prompt prompt = ChatMessages.fromList(Collections.emptyList());
41+
Choir<ChatMessage> dataSource = Choir.create(emitter -> {
42+
emitter.emit(new AiMessage(output));
43+
emitter.complete();
44+
});
45+
FlowSession flowSession = new FlowSession();
46+
Memory mockMemory = Mockito.mock(Memory.class);
47+
ArgumentCaptor<ChatMessage> captor = ArgumentCaptor.forClass(ChatMessage.class);
48+
Mockito.doNothing().when(mockMemory).add(captor.capture());
49+
flowSession.setInnerState(StateKey.HISTORY, mockMemory);
50+
51+
LlmEmitter<ChatMessage> llmEmitter = new LlmEmitter<>(dataSource, prompt, flowSession);
52+
llmEmitter.start(flowSession);
53+
54+
List<ChatMessage> captured = captor.getAllValues();
55+
assertEquals(2, captured.size());
56+
assertEquals(StringUtils.EMPTY, captured.get(0).text());
57+
assertEquals(output, captured.get(1).text());
58+
}
59+
60+
@Test
61+
void shouldNotAddMemoryWhenCompleteGivenLlmToolCallOutput() {
62+
String output = "data1";
63+
Prompt prompt = ChatMessages.fromList(Collections.emptyList());
64+
Choir<ChatMessage> dataSource = Choir.create(emitter -> {
65+
emitter.emit(new AiMessage(output, Arrays.asList(ToolCall.custom().id("id1").build())));
66+
emitter.complete();
67+
});
68+
FlowSession flowSession = new FlowSession();
69+
Memory mockMemory = Mockito.mock(Memory.class);
70+
flowSession.setInnerState(StateKey.HISTORY, mockMemory);
71+
72+
LlmEmitter<ChatMessage> llmEmitter = new LlmEmitter<>(dataSource, prompt, flowSession);
73+
llmEmitter.start(flowSession);
74+
75+
Mockito.verify(mockMemory, Mockito.times(0)).add(Mockito.any());
76+
}
77+
}

0 commit comments

Comments
 (0)