Skip to content

Commit bd9bcfa

Browse files
committed
[fel] implement memory for recent N histories
1 parent 72e8291 commit bd9bcfa

File tree

2 files changed

+158
-0
lines changed

2 files changed

+158
-0
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*---------------------------------------------------------------------------------------------
2+
* Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3+
* This file is a part of the ModelEngine Project.
4+
* Licensed under the MIT License. See License.txt in the project root for license information.
5+
*--------------------------------------------------------------------------------------------*/
6+
7+
package modelengine.fel.core.memory.support;
8+
9+
import modelengine.fel.core.chat.ChatMessage;
10+
import modelengine.fel.core.memory.Memory;
11+
import modelengine.fel.core.template.BulkStringTemplate;
12+
import modelengine.fel.core.template.support.DefaultBulkStringTemplate;
13+
import modelengine.fitframework.inspection.Validation;
14+
import modelengine.fitframework.util.MapBuilder;
15+
16+
import java.util.List;
17+
import java.util.Map;
18+
import java.util.concurrent.LinkedBlockingQueue;
19+
import java.util.function.Function;
20+
import java.util.stream.Collectors;
21+
22+
import static modelengine.fitframework.inspection.Validation.notNull;
23+
24+
/**
25+
* 表示使用最近一定次数历史记录的实现。
26+
*
27+
* @author 宋永坦
28+
* @since 2025-07-04
29+
*/
30+
public class RecentMemory implements Memory {
31+
private final LinkedBlockingQueue<ChatMessage> records;
32+
private final BulkStringTemplate bulkTemplate;
33+
private final Function<ChatMessage, Map<String, String>> extractor;
34+
35+
/**
36+
* 指定最大保留历史记录数量的构造方法。
37+
*
38+
* @param maxCount 表示最大保留历史记录数量的 {@code int}。
39+
* @throws IllegalArgumentException 当 {@code maxCount < 0} 时。
40+
*/
41+
public RecentMemory(int maxCount) {
42+
this(maxCount,
43+
new DefaultBulkStringTemplate("{{type}}:{{text}}", "\n"),
44+
message -> MapBuilder.<String, String>get()
45+
.put("type", message.type().getRole())
46+
.put("text", message.text())
47+
.build());
48+
}
49+
50+
/**
51+
* 指定最大保留历史记录数量、渲染模板、抽取方法的构造方法。
52+
*
53+
* @param maxCount 表示最大保留历史记录数量的 {@code int}。
54+
* @param bulkTemplate 表示批量字符串模板的 {@link BulkStringTemplate}。
55+
* @param extractor 表示将 {@link ChatMessage} 转换成
56+
* {@link Map}{@code <}{@link String}, {@link String}{@code >} 的处理函数。
57+
* @throws IllegalArgumentException 当 {@code maxCount < 0}、{@code bulkTemplate}、{@code extractor} 为 {@code null} 时。
58+
*/
59+
public RecentMemory(int maxCount, BulkStringTemplate bulkTemplate,
60+
Function<ChatMessage, Map<String, String>> extractor) {
61+
Validation.greaterThanOrEquals(maxCount, 0, "The max count should >= 0.");
62+
this.records = new LinkedBlockingQueue<>(maxCount);
63+
this.bulkTemplate = notNull(bulkTemplate, "The bulkTemplate cannot be null.");
64+
this.extractor = notNull(extractor, "The extractor cannot be null.");
65+
}
66+
67+
@Override
68+
public void add(ChatMessage message) {
69+
if (!this.records.offer(message)) {
70+
this.records.poll();
71+
this.records.offer(message);
72+
}
73+
}
74+
75+
@Override
76+
public void set(List<ChatMessage> messages) {
77+
messages.forEach(this::add);
78+
}
79+
80+
@Override
81+
public void clear() {
82+
this.records.clear();
83+
}
84+
85+
@Override
86+
public List<ChatMessage> messages() {
87+
return this.records.stream().toList();
88+
}
89+
90+
@Override
91+
public String text() {
92+
return this.records.stream()
93+
.map(this.extractor)
94+
.collect(Collectors.collectingAndThen(Collectors.toList(), this.bulkTemplate::render));
95+
}
96+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*---------------------------------------------------------------------------------------------
2+
* Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3+
* This file is a part of the ModelEngine Project.
4+
* Licensed under the MIT License. See License.txt in the project root for license information.
5+
*--------------------------------------------------------------------------------------------*/
6+
7+
package modelengine.fel.core.memory.support;
8+
9+
import modelengine.fel.core.chat.ChatMessage;
10+
import modelengine.fel.core.chat.support.AiMessage;
11+
12+
import org.junit.jupiter.api.Test;
13+
14+
import java.util.Arrays;
15+
import java.util.List;
16+
17+
import static org.junit.jupiter.api.Assertions.*;
18+
19+
/**
20+
* 表示 {@link RecentMemory} 的测试。
21+
*
22+
* @author 宋永坦
23+
* @since 2025-07-04
24+
*/
25+
class RecentMemoryTest {
26+
private final List<ChatMessage> inputChatMessages =
27+
Arrays.asList(new AiMessage("1"), new AiMessage("2"), new AiMessage("3"));
28+
29+
@Test
30+
void shouldKeepAllMessagesWhenAddGivenLessMessage() {
31+
RecentMemory recentMemory = new RecentMemory(4);
32+
this.inputChatMessages.forEach(recentMemory::add);
33+
List<ChatMessage> messages = recentMemory.messages();
34+
35+
assertEquals(inputChatMessages.size(), messages.size());
36+
for (int i = 0; i < inputChatMessages.size(); ++i) {
37+
assertEquals(inputChatMessages.get(i).text(), messages.get(i).text());
38+
}
39+
}
40+
41+
@Test
42+
void shouldKeepMaxCountMessagesWhenAddGivenOverMaxCountMessages() {
43+
RecentMemory recentMemory = new RecentMemory(2);
44+
this.inputChatMessages.forEach(recentMemory::add);
45+
List<ChatMessage> messages = recentMemory.messages();
46+
47+
assertEquals(2, messages.size());
48+
assertEquals(inputChatMessages.get(1).text(), messages.get(0).text());
49+
assertEquals(inputChatMessages.get(2).text(), messages.get(1).text());
50+
}
51+
52+
@Test
53+
void shouldKeepMaxCountMessagesWhenSetGivenOverMaxCountMessages() {
54+
RecentMemory recentMemory = new RecentMemory(2);
55+
recentMemory.set(this.inputChatMessages);
56+
List<ChatMessage> messages = recentMemory.messages();
57+
58+
assertEquals(2, messages.size());
59+
assertEquals(inputChatMessages.get(1).text(), messages.get(0).text());
60+
assertEquals(inputChatMessages.get(2).text(), messages.get(1).text());
61+
}
62+
}

0 commit comments

Comments
 (0)