Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/fel-example/05-retrieval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ node0-->node1{{=}}

## 验证

- 在IDEA中运行`DemoApplication`

- 在浏览器栏输入:`http://localhost:8080/ai/example/chat?query=请介绍一下黑神话悟空`

```json
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import modelengine.fitframework.annotation.Component;
import modelengine.fitframework.annotation.Fit;
import modelengine.fitframework.annotation.Value;
import modelengine.fitframework.log.Logger;
import modelengine.fitframework.serialization.ObjectSerializer;
import modelengine.fitframework.util.FileUtils;

Expand All @@ -57,6 +58,7 @@
@Component
@RequestMapping("/ai/example")
public class RetrievalExampleController {
private static final Logger log = Logger.get(RetrievalExampleController.class);
private static final String REWRITE_PROMPT =
"作为一个向量检索助手,你的任务是结合历史记录,为”原问题“生成”检索词“," + "生成的问题要求指向对象清晰明确,并与“原问题语言相同。\n\n"
+ "历史记录:\n---\n" + DEFAULT_HISTORY_KEY + "---\n原问题:{{query}}\n检索词:";
Expand Down Expand Up @@ -85,22 +87,27 @@ public RetrievalExampleController(ChatModel chatModel, EmbedModel embedModel,
.others(node -> node.map(tip -> tip.freeze().get("query").text()))
.retrieve(new DefaultVectorRetriever(vectorStore, SearchOption.custom().topK(1).build()))
.synthesize(docs -> Content.from(docs.stream().map(Document::text).collect(Collectors.joining("\n\n"))))
.close();
.close(__ -> log.info("Retrieve flow completed."));

AiProcessFlow<File, List<Document>> indexFlow = AiFlows.<File>create()
.load(new JsonFileSource(serializer, StringTemplate.create("{{question}}: {{answer}}")))
.index(vectorStore)
.close();
File file = FileUtils.file(this.getClass().getClassLoader().getResource("data.json"));
notNull(file, "The data cannot be null.");
indexFlow.converse().offer(file);
indexFlow.converse()
.doOnError(e -> log.info("Index build error. [error={}]", e.getMessage(), e))
.doOnFinally(() -> log.info("Index build successfully."))
.offer(file);

this.ragFlow = AiFlows.<String>create()
.just(query -> log.info("RAG flow start. [query={}]", query))
.map(query -> Tip.from("query", query))
.runnableParallel(value("context", retrieveFlow), passThrough())
.prompt(Prompts.history(), Prompts.human(CHAT_PROMPT))
.just(__ -> log.info("LLM start generation."))
.generate(chatFlowModel)
.close();
.close(__ -> log.info("RAG flow completed."));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

package modelengine.fit.waterflow.domain.context;

import modelengine.fit.waterflow.domain.context.repo.flowsession.FlowSessionRepo;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
Expand All @@ -21,8 +22,6 @@
* @since 1.0
*/
public class MatchWindow extends Window {
private static final Map<String, MatchWindow> all = new ConcurrentHashMap<>();

private final Set<MatchWindow> arms = new HashSet<>();

/**
Expand All @@ -41,22 +40,26 @@ public MatchWindow(Window source, UUID id, Object data) {
/**
* 创建一个MatchWindow
*
* @param flowId 流程ID
* @param source 源窗口
* @param id 窗口ID
* @param data 窗口数据
* @return 返回创建的MatchWindow对象
*/
public static synchronized MatchWindow from(Window source, UUID id, Object data) {
MatchWindow window = all.get(id.toString());
public static synchronized MatchWindow from(String flowId, Window source, UUID id, Object data) {
// 从 FlowSessionRepo 获取缓存
Map<UUID, MatchWindow> cache = FlowSessionRepo.getMatchWindowCache(flowId, source.getSession());

MatchWindow window = cache.get(id);
if (window == null) {
window = new MatchWindow(source, id, data);
FlowSession session = new FlowSession(source.getSession());
session.setWindow(window);
all.put(id.toString(), window);
cache.put(id, window);
}
WindowToken token = window.createToken();
token.beginConsume();
List<MatchWindow> arms = all.values().stream().filter(t -> t.from == source).collect(Collectors.toList());
List<MatchWindow> arms = cache.values().stream().filter(t -> t.from == source).collect(Collectors.toList());
for (MatchWindow a : arms) {
a.setArms(arms);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import modelengine.fit.waterflow.domain.context.FlatMapSourceWindow;
import modelengine.fit.waterflow.domain.context.FlowSession;
import modelengine.fit.waterflow.domain.context.MatchWindow;
import modelengine.fit.waterflow.domain.context.Window;
import modelengine.fit.waterflow.domain.context.repo.flowcontext.FlowContextRepo;
import modelengine.fitframework.inspection.Validation;
Expand Down Expand Up @@ -90,6 +91,19 @@ public static FlatMapSourceWindow getFlatMapSource(String flowId, Window window,
.getFlatMapSourceWindow(window, repo);
}

/**
 * Returns the {@code MatchWindow} cache map associated with the given flow session,
 * used to store and look up {@code MatchWindow} instances by their unique id.
 *
 * <p>The map is obtained from the per-flow session cache, so entries are scoped to the
 * flow/session pair rather than held in a process-wide static map.
 *
 * @param flowId The unique identifier of the flow; must not be {@code null}.
 * @param session The current session context; must not be {@code null}.
 * @return The mutable {@code MatchWindow} cache map for this flow session, keyed by window id.
 * @throws IllegalArgumentException If {@code flowId} or {@code session} is {@code null}
 *         (raised by the {@code Validation.notNull} precondition checks).
 */
public static Map<UUID, MatchWindow> getMatchWindowCache(String flowId, FlowSession session) {
Validation.notNull(flowId, "Flow id cannot be null.");
Validation.notNull(session, "Session cannot be null.");
return getFlowSessionCache(flowId, session).getMatchWindowCache();
}

/**
* Releases all resources associated with a specific flow session.
*
Expand Down Expand Up @@ -137,6 +151,12 @@ private static class FlowSessionCache {
*/
private final Map<UUID, FlatMapSourceWindow> flatMapSourceWindows = new ConcurrentHashMap<>();

/**
* 记录流程中条件匹配节点产生的窗口信息,用于将同一批数据汇聚。
* 其中索引为 match window 的唯一标识。
*/
private final Map<UUID, MatchWindow> matchWindows = new ConcurrentHashMap<>();

private final Map<String, Integer> accOrders = new ConcurrentHashMap<>();

private FlowSession getNextToSession(FlowSession session) {
Expand Down Expand Up @@ -165,6 +185,10 @@ private FlatMapSourceWindow getFlatMapSourceWindow(Window window, FlowContextRep
});
}

/**
 * Returns this session cache's {@code MatchWindow} map (backed by a concurrent map,
 * see the {@code matchWindows} field), keyed by the match window's unique id.
 */
private Map<UUID, MatchWindow> getMatchWindowCache() {
return this.matchWindows;
}

private int getNextAccOrder(String nodeId) {
return this.accOrders.compute(nodeId, (key, value) -> {
if (value == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ public MatchHappen<O, D, I, F> match(Operators.Whether<I> whether,
Operators.BranchProcessor<O, D, I, F> processor) {
UUID id = UUID.randomUUID();
State<I, D, I, F> branchStart = new State<>(this.node.publisher()
.just(input -> input.setSession(
MatchWindow.from(input.getWindow(), id, input.getData()).getSession()), whether)
.just(input -> input.setSession(MatchWindow.from(this.node.processor.getStreamId(),
input.getWindow(), id, input.getData()).getSession()), whether)
.displayAs(SpecialDisplayNode.BRANCH.name()), this.node.getFlow());
State<O, D, ?, F> branch = processor.process(branchStart);
this.branches.add(branch);
Expand Down