Skip to content

Commit 6fcb39a

Browse files
authored
Merge pull request #1587 from didiforgithub/main
Add Features to AFLOW.
2 parents 532aede + e2cdcfb commit 6fcb39a

File tree

2 files changed

+104
-6
lines changed

2 files changed

+104
-6
lines changed

metagpt/actions/action_node.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -541,22 +541,22 @@ async def code_fill(
541541
result = {field_name: extracted_code}
542542
return result
543543

544-
async def single_fill(self, context: str) -> Dict[str, str]:
544+
async def single_fill(self, context: str, images: Optional[Union[str, list[str]]] = None) -> Dict[str, str]:
545545
field_name = self.get_field_name()
546546
prompt = context
547-
content = await self.llm.aask(prompt)
547+
content = await self.llm.aask(prompt, images=images)
548548
result = {field_name: content}
549549
return result
550550

551-
async def xml_fill(self, context: str) -> Dict[str, Any]:
551+
async def xml_fill(self, context: str, images: Optional[Union[str, list[str]]] = None) -> Dict[str, Any]:
552552
"""
553553
Fill context with XML tags and convert according to field types, including string, integer, boolean, list and dict types
554554
"""
555555
field_names = self.get_field_names()
556556
field_types = self.get_field_types()
557557

558558
extracted_data: Dict[str, Any] = {}
559-
content = await self.llm.aask(context)
559+
content = await self.llm.aask(context, images=images)
560560

561561
for field_name in field_names:
562562
pattern = rf"<{field_name}>(.*?)</{field_name}>"
@@ -635,12 +635,12 @@ async def fill(
635635

636636
elif mode == FillMode.XML_FILL.value:
637637
context = self.xml_compile(context=self.context)
638-
result = await self.xml_fill(context)
638+
result = await self.xml_fill(context, images=images)
639639
self.instruct_content = self.create_class()(**result)
640640
return self
641641

642642
elif mode == FillMode.SINGLE_FILL.value:
643-
result = await self.single_fill(context)
643+
result = await self.single_fill(context, images=images)
644644
self.instruct_content = self.create_class()(**result)
645645
return self
646646

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# -*- coding: utf-8 -*-
2+
# @Date : 2024-03-21
3+
# @Author : Your Name
4+
# @Desc : Interface for AFLOW
5+
6+
import asyncio
7+
import importlib.util
8+
import sys
9+
from pathlib import Path
10+
from typing import Optional, Tuple
11+
12+
from metagpt.configs.models_config import ModelsConfig
13+
from metagpt.ext.aflow.scripts.evaluator import DatasetType
14+
from metagpt.ext.aflow.scripts.optimizer_utils.data_utils import DataUtils
15+
from metagpt.logs import logger
16+
17+
18+
def load_best_round(dataset: str, optimized_path: str = "metagpt/ext/aflow/scripts/optimized") -> int:
19+
"""加载最佳表现的轮次"""
20+
data_utils = DataUtils(f"{optimized_path}/{dataset}")
21+
22+
# 使用get_top_rounds获取得分最高的轮次
23+
top_rounds = data_utils.get_top_rounds(sample=2, mode="Graph")
24+
if not top_rounds[1]:
25+
return 1
26+
27+
return top_rounds[1]["round"]
28+
29+
30+
def load_workflow_class(graph_path: str):
31+
"""动态加载工作流类"""
32+
spec = importlib.util.spec_from_file_location("workflow_module", graph_path)
33+
module = importlib.util.module_from_spec(spec)
34+
sys.modules["workflow_module"] = module
35+
spec.loader.exec_module(module)
36+
return module.Workflow
37+
38+
39+
async def aflow_inference(
40+
dataset: DatasetType,
41+
question: str,
42+
entry_point: Optional[str] = None,
43+
round: Optional[int] = None,
44+
llm_name: str = "gpt-4o-mini",
45+
optimized_path: str = "metagpt/ext/aflow/scripts/optimized",
46+
) -> Tuple[str, float]:
47+
"""AFLOW推理接口
48+
49+
Args:
50+
dataset: 数据集名称
51+
question: 输入问题
52+
round: 指定使用的轮次,如果为None则使用最佳轮次
53+
llm_name: 使用的LLM模型名称
54+
optimized_path: 优化结果保存路径
55+
56+
Returns:
57+
(答案, 成本)的元组
58+
"""
59+
# 如果没有指定轮次,使用最佳轮次
60+
if round is None:
61+
round = load_best_round(dataset, optimized_path)
62+
63+
logger.info(f"Using round {round} for inference")
64+
65+
# 构建工作流路径并加载
66+
graph_path = Path(optimized_path) / dataset / "workflows" / f"round_{round}" / "graph.py"
67+
if not graph_path.exists():
68+
raise FileNotFoundError(f"Workflow file not found: {graph_path}")
69+
70+
# 动态加载工作流类
71+
WorkflowClass = load_workflow_class(str(graph_path))
72+
73+
# 创建工作流实例
74+
llm_config = ModelsConfig.default().get(llm_name)
75+
workflow = WorkflowClass(
76+
name=f"{dataset}_workflow",
77+
llm_config=llm_config,
78+
dataset=dataset,
79+
)
80+
81+
# 执行推理
82+
if dataset in ["MBPP", "HumanEval"]:
83+
# 代码类任务需要额外的entry_point参数
84+
answer, cost = await workflow(question, entry_point=entry_point)
85+
else:
86+
answer, cost = await workflow(question)
87+
88+
return answer, cost
89+
90+
91+
if __name__ == "__main__":
92+
asyncio.run(
93+
aflow_inference(
94+
dataset="MBPP",
95+
question="write a function named add_two_numbers to calculate the sum of two numbers",
96+
entry_point="add_two_numbers",
97+
)
98+
)

0 commit comments

Comments
 (0)