diff --git a/dump.rdb b/dump.rdb new file mode 100644 index 0000000000..d08001ae89 Binary files /dev/null and b/dump.rdb differ diff --git a/metagpt/ext/spo_api_backend/Readme.md b/metagpt/ext/spo_api_backend/Readme.md new file mode 100644 index 0000000000..849d4cee9a --- /dev/null +++ b/metagpt/ext/spo_api_backend/Readme.md @@ -0,0 +1,28 @@ +# SPO API BACKEND +**This backend exposes an SPO API that lets you integrate prompt-optimization capabilities into your own applications. It supports a task queue, so you can submit many tasks and the service will optimize the prompts in order.** +**This code is built on modules from `metagpt/ext/spo`. If necessary, you can fold the code it uses from `metagpt/ext/spo` into `spo_api_backend` yourself.** +**A simple graphical front end built with Streamlit is provided for this API at `metagpt/ext/spo_api_backend/frontend_sample/spo_gui.py`. You can use this tool to conveniently optimize a large number of prompts.** +*To run the full functionality, execute the commands in the following order:* +```bash +# Start the API service +redis-server +celery -A metagpt.ext.spo_api_backend.celery_app worker --loglevel=info --pool=solo +python -m metagpt.ext.spo_api_backend.spo_api + +# Launch the graphical interface (run from the metagpt/ext/spo_api_backend/frontend_sample directory) +streamlit run spo_gui.py + +``` + +This feature is built on the MetaGPT project and is used to optimize prompts. +MetaGPT project address: https://github.com/FoundationAgents/MetaGPT + +The API backend functionality was created by aflyqi +https://github.com/aflyqi + +If you run into any issues, you can contact me by email: +2726132097@qq.com \ No newline at end of file diff --git a/metagpt/ext/spo_api_backend/__init__.py b/metagpt/ext/spo_api_backend/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/metagpt/ext/spo_api_backend/celery_app.py b/metagpt/ext/spo_api_backend/celery_app.py new file mode 100644 index 0000000000..1f845703d6 --- /dev/null +++ b/metagpt/ext/spo_api_backend/celery_app.py @@ -0,0 +1,16 @@ +# celery_app.py +from celery import Celery + +celery_app = Celery( + "spo_tasks", + broker="redis://localhost:6379/0", + backend="redis://localhost:6379/1" +) + +celery_app.conf.update( + task_track_started=True, + task_time_limit=3600, +) + +# Enable Celery to automatically scan the task module +celery_app.autodiscover_tasks(['metagpt.ext.spo_api_backend']) diff --git a/metagpt/ext/spo_api_backend/frontend_sample/spo_gui.py b/metagpt/ext/spo_api_backend/frontend_sample/spo_gui.py new file mode 100644 index 0000000000..c763ff8eec --- /dev/null +++ b/metagpt/ext/spo_api_backend/frontend_sample/spo_gui.py @@ -0,0 +1,460 @@ +import os +import uuid +import json +from pathlib import Path +from typing import Dict, Any, List +from datetime import datetime + +import requests +import streamlit as st +import yaml + +# Historical task storage path +HISTORY_FILE = Path("spo_history.json") + +############################################################################### +# ------------------------- COMPATIBILITY HELPERS ------------------------- # +############################################################################### + +def safe_rerun() -> None: + """Call `st.rerun()` if available; fall back to `st.experimental_rerun()` if not.""" + if hasattr(st, "rerun"): + st.rerun() + elif hasattr(st, "experimental_rerun"): + st.experimental_rerun() + +############################################################################### +# ---------------------------- HISTORY PERSISTENCE ------------------------ # +############################################################################### + +def load_history() -> Dict[str, Any]: + """Load historical tasks from local files""" + if HISTORY_FILE.exists(): + try: + with HISTORY_FILE.open("r", encoding="utf-8") as f: + 
return json.load(f) + except Exception: + return {} + return {} + +def save_history(tasks: Dict[str, Any]) -> None: + """Save task history to local file""" + try: + with HISTORY_FILE.open("w", encoding="utf-8") as f: + json.dump(tasks, f, ensure_ascii=False, indent=2) + except Exception as e: + st.error(f"Failed to save history: {e}") + +############################################################################### +# ---------------------------- API HELPER CLASS --------------------------- # +############################################################################### + +class SPOClient: + """Lightweight client for the SPO FastAPI backend.""" + + def __init__(self, base_url: str = "http://localhost:8000") -> None: + self.base_url = base_url.rstrip("/") + + def start_optimization( + self, + optimization_model: str, + optimization_temp: float, + evaluation_model: str, + evaluation_temp: float, + execution_model: str, + execution_temp: float, + template_path: str, + initial_round: int = 1, + max_rounds: int = 10, + task_name: str | None = None, + ) -> Dict[str, Any]: + payload: Dict[str, Any] = { + "optimization_model": { + "model": optimization_model, + "temperature": optimization_temp, + }, + "evaluation_model": { + "model": evaluation_model, + "temperature": evaluation_temp, + }, + "execution_model": { + "model": execution_model, + "temperature": execution_temp, + }, + "template_path": template_path, + "initial_round": int(initial_round), + "max_rounds": int(max_rounds), + } + if task_name: + payload["task_name"] = task_name + + res = requests.post(f"{self.base_url}/optimize", json=payload, timeout=30) + res.raise_for_status() + return res.json() + + def safe_get_status(self, task_id: str) -> Dict[str, Any]: + try: + res = requests.get(f"{self.base_url}/status/{task_id}", timeout=10) + res.raise_for_status() + return res.json() + except Exception as exc: # noqa: BLE001 + return {"task_id": task_id, "status": "error", "error_message": str(exc)} + 
+############################################################################### +# ------------------------------ TEMPLATE UTILS ---------------------------- # +############################################################################### + +def generate_template_yaml( + prompt: str, + requirements: str, + qa_list: List[Dict[str, str]], + output_dir: str | Path = r"settings", +) -> str: + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + fname = output_dir / f"template_{uuid.uuid4().hex}.yaml" + + yaml_obj = { + "prompt": prompt, + "requirements": requirements, + "count": None, + "qa": qa_list, + } + + with fname.open("w", encoding="utf-8") as fp: + yaml.safe_dump(yaml_obj, fp, allow_unicode=True, sort_keys=False) + + return str(fname.resolve()) + +############################################################################### +# ---------------------------- STREAMLIT STATE ---------------------------- # +############################################################################### + +def init_session_state() -> None: + # 从历史文件加载任务 + history = load_history() + st.session_state.setdefault("tasks", history) + st.session_state.setdefault("selected_task", None) + st.session_state.setdefault("qa_buffer", []) + st.session_state.setdefault("current_view", "active") # active 或 history + +############################################################################### +# ---------------------------- SIDEBAR LAYOUT ----------------------------- # +############################################################################### + +def sidebar() -> None: + st.sidebar.title("🗂️ task management") + + # 视图切换 + view_choice = st.sidebar.radio("view mode", ["🔄 active Tasks", "📚 historical tasks"]) + st.session_state.current_view = "active" if "active" in view_choice else "history" + + tasks = st.session_state.tasks + + if st.session_state.current_view == "active": + # 显示活跃任务(运行中或待刷新的任务) + active_tasks = {tid: info for tid, info in tasks.items() + if 
info.get("status") not in {"completed", "failed", "error"}} + task_labels = ["➕ new task"] + [ + f"{tid[:8]} | {info.get('task_name', 'Unnamed')}" + for tid, info in active_tasks.items() + ] + choice = st.sidebar.radio("select task:", task_labels, index=0) + + if choice == "➕ new task": + st.session_state.selected_task = None + else: + sel_prefix = choice.split(" | ")[0] + for tid in active_tasks: + if tid.startswith(sel_prefix): + st.session_state.selected_task = tid + break + + with st.sidebar.expander("🔄 Refresh active tasks", expanded=False): + if st.button("🔃 Refresh all active task statuses", use_container_width=True): + client = SPOClient() + for tid, info in active_tasks.items(): + updated_info = client.safe_get_status(tid) + tasks[tid].update(updated_info) + save_history(tasks) + safe_rerun() + + else: # 历史视图 + # 显示所有任务,按创建时间排序 + sorted_tasks = sorted(tasks.items(), + key=lambda x: x[1].get('created_at', ''), reverse=True) + + if not sorted_tasks: + st.sidebar.info("There are currently no historical tasks available") + st.session_state.selected_task = None + else: + task_options = {} + for tid, info in sorted_tasks: + status_emoji = {"completed": "✅", "failed": "❌", "error": "⚠️", "running": "🔄"}.get(info.get("status"), "❓") + label = f"{status_emoji} {tid[:8]} | {info.get('task_name', 'Unnamed')}" + task_options[label] = tid + + choice = st.sidebar.selectbox("Select historical tasks:", list(task_options.keys())) + if choice: + st.session_state.selected_task = task_options[choice] + + with st.sidebar.expander("🗑️ Manage historical tasks", expanded=False): + if st.button("🔃 Refresh all task statuses", use_container_width=True): + client = SPOClient() + for tid, info in tasks.items(): + if info.get("status") not in {"completed", "failed", "error"}: + updated_info = client.safe_get_status(tid) + tasks[tid].update(updated_info) + save_history(tasks) + safe_rerun() + + st.markdown("**Delete task:**") + deletable = [tid for tid, info in tasks.items() + if 
info.get("status") in {"completed", "failed", "error"}] + + if deletable: + selected_for_deletion = st.multiselect( + "Select the task to delete:", + options=deletable, + format_func=lambda x: f"{x[:8]} | {tasks[x].get('task_name', 'Unnamed')}" + ) + + if selected_for_deletion and st.button("🗑️ Delete selected task", use_container_width=True): + for tid in selected_for_deletion: + tasks.pop(tid, None) + if st.session_state.selected_task in selected_for_deletion: + st.session_state.selected_task = None + save_history(tasks) + safe_rerun() + else: + st.info("There are no tasks that can be deleted") + +############################################################################### +# --------------------------- QA EDITOR HELPER ----------------------------- # +############################################################################### + +def show_qa_editor(in_form: bool) -> None: + qa_list: List[Dict[str, str]] = st.session_state.qa_buffer + + for idx, pair in enumerate(qa_list): + q, a = st.columns(2) + pair["question"] = q.text_input(f"Question {idx+1}", value=pair.get("question", ""), key=f"q_{idx}") + pair["answer"] = a.text_input(f"Answer {idx+1}", value=pair.get("answer", ""), key=f"a_{idx}") + + ctrl = st.columns(3) + if in_form: + add = ctrl[0].form_submit_button("➕ Add QA") + rm = ctrl[1].form_submit_button("➖ Remove the last item") + clr = ctrl[2].form_submit_button("🧹 Clear") + else: + add = ctrl[0].button("➕ Add QA") + rm = ctrl[1].button("➖ Remove the last item") + clr = ctrl[2].button("🧹 Clear") + + if add: + qa_list.append({"question": "", "answer": ""}) + if rm and qa_list: + qa_list.pop() + if clr and qa_list: + qa_list.clear() + +############################################################################### +# --------------------------- NEW TASK LAYOUT ----------------------------- # +############################################################################### + +def render_new_task_ui() -> None: + st.header("🆕 New optimization task") + + 
with st.form("task_form", clear_on_submit=False): + c1, c2 = st.columns(2) + task_name = c1.text_input("Task Name", value="my_optimization_task") + base_url = c2.text_input("API Base URL", value="http://localhost:8000") + + st.subheader("✍️ model selection") + models = [ + "Qwen1.5-72B-Chat-AWQ", + "Qwen3-32B-AWQ", + "gpt-4o", + "claude-3-5-sonnet-20240620", + "deepseek-chat", + ] + mcol = st.columns(3) + opt_model = mcol[0].selectbox("Optimization", models, index=0) + eval_model = mcol[1].selectbox("Evaluation", models, index=3) + exe_model = mcol[2].selectbox("Execution", models, index=2) + + tcol = st.columns(3) + opt_temp = tcol[0].slider("Opt Temp", 0.0, 1.0, 0.7, 0.05) + eval_temp = tcol[1].slider("Eval Temp", 0.0, 1.0, 0.3, 0.05) + exe_temp = tcol[2].slider("Exec Temp", 0.0, 1.0, 0.0, 0.05) + + st.divider() + st.subheader("📑 template") + template_path = st.text_input("template_path") + + with st.expander("🛠️ Create/Edit Template", expanded=False): + ptxt = st.text_area("Prompt", key="tmp_prompt") + rtxt = st.text_area("Requirements", key="tmp_req") + st.markdown("#### QA List") + show_qa_editor(in_form=True) + if st.form_submit_button("⏺️ generation template", use_container_width=True): + if not ptxt or not rtxt: + st.warning("Prompt and Requirements cannot be empty!") + else: + BASE_DIR = Path(__file__).resolve().parent + output_dir = BASE_DIR.parent.parent / "spo" / "settings" + tpath = generate_template_yaml(ptxt, rtxt, st.session_state.qa_buffer,output_dir) + st.success(f"Template has been generated: {tpath}") + st.session_state.template_path_value = tpath + + # autofill + if "template_path_value" in st.session_state and not template_path: + template_path = st.session_state.template_path_value + + st.subheader("🔢 Round number setting") + rcol = st.columns(2) + init_round = int(rcol[0].number_input("Initial Round", 1, 100, 1)) + max_round = int(rcol[1].number_input("Max Rounds", 1, 100, 10)) + + start_clicked = st.form_submit_button("🚀 Start 
optimization", use_container_width=True) + + # ---- 表单外逻辑 ---- # + if start_clicked: + if not template_path: + st.error("Please provide a template or create a template!") + st.stop() + + client = SPOClient(base_url) + with st.spinner("Starting task..."): + try: + info = client.start_optimization( + opt_model, + opt_temp, + eval_model, + eval_temp, + exe_model, + exe_temp, + template_path, + init_round, + max_round, + task_name=task_name, + ) + except Exception as e: # noqa: BLE001 + st.error(f"FAIL TO START: {e}") + st.stop() + + tid = info["task_id"] + + # 保存完整的任务信息 + task_info = { + **info, + "task_name": task_name, # 确保保存任务名称 + "base_url": base_url, + "created_at": datetime.now().isoformat(), + "config": { + "optimization_model": opt_model, + "evaluation_model": eval_model, + "execution_model": exe_model, + "optimization_temp": opt_temp, + "evaluation_temp": eval_temp, + "execution_temp": exe_temp, + "template_path": template_path, + "initial_round": init_round, + "max_rounds": max_round, + } + } + + st.session_state.tasks[tid] = task_info + save_history(st.session_state.tasks) # 保存到历史文件 + st.session_state.selected_task = tid + st.success(f"✅ Task created!Task ID: {tid}") + safe_rerun() + +############################################################################### +# ------------------------- TASK DISPLAY LAYOUT --------------------------- # +############################################################################### + +def render_task_view(tid: str) -> None: + task = st.session_state.tasks[tid] + + # 任务标题和复制按钮 + col1, col2 = st.columns([3, 1]) + with col1: + st.header(f"📊 Task: {task.get('task_name', 'Unnamed')}") + with col2: + if st.button("📋 Copy Task ID", key=f"copy_{tid}"): + st.code(tid, language=None) + st.success("Task ID has been displayed, please manually copy it") + + # 基本信息 + status = task.get("status", "unknown") + status_emoji = {"completed": "✅", "failed": "❌", "error": "⚠️", "running": "🔄"}.get(status, "❓") + st.markdown(f"**state**: 
{status_emoji} `{status}`") + + # 显示配置信息 + if "config" in task: + config = task["config"] + with st.expander("⚙️ Task Configuration", expanded=False): + col1, col2 = st.columns(2) + with col1: + st.markdown("**Model configuration:**") + st.markdown(f"- optimization model: `{config.get('optimization_model')}`") + st.markdown(f"- evaluation model: `{config.get('evaluation_model')}`") + st.markdown(f"- execution model: `{config.get('execution_model')}`") + with col2: + st.markdown("**parameter configuration:**") + st.markdown(f"- Initial number of rounds: `{config.get('initial_round')}`") + st.markdown(f"- Maximum number of rounds: `{config.get('max_rounds')}`") + st.markdown(f"- template file: `{Path(config.get('template_path', '')).name}`") + + # 刷新按钮 + if st.button("🔃 Refresh the status of this task", key=f"refresh_{tid}"): + client = SPOClient(task.get("base_url", "http://localhost:8000")) + updated_info = client.safe_get_status(tid) + task.update(updated_info) + save_history(st.session_state.tasks) # 保存更新后的状态 + safe_rerun() + + # 任务统计 + if status in {"completed", "failed"}: + st.markdown( + f"**用时**: {task.get('elapsed_time', 0):.2f}s | **Number of successful rounds**: {task.get('successful_rounds', 0)} / {task.get('total_rounds', 0)}" + ) + + # 结果显示 + if "results" in task and task["results"]: + st.subheader("📈 Round Details") + for res in task["results"]: + emoji = "✅" if res.get("succeed") else "❌" + with st.expander(f"Round {res['round']} {emoji}"): + st.markdown("**Prompt**") + st.code(res.get("prompt", ""), language="text") + if res.get("answers"): + st.markdown("**Q&A Result**") + for qa in res.get("answers", []): + st.markdown(f"- **{qa['question']}**: {qa['answer']}") + elif status == "running": + st.info("The task is running, refresh later to see the results...") + elif status == "error": + st.error(f"task exception: {task.get('error_message')}") + +############################################################################### +# 
--------------------------------- MAIN ----------------------------------- # +############################################################################### + +def main() -> None: + st.set_page_config("SPO Prompt Optimizer", layout="wide") + init_session_state() + sidebar() + + sel = st.session_state.selected_task + if sel is None: + if st.session_state.current_view == "active": + render_new_task_ui() + else: + st.info("Please select a historical task from the left to view details") + else: + render_task_view(sel) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/metagpt/ext/spo_api_backend/schemas.py b/metagpt/ext/spo_api_backend/schemas.py new file mode 100644 index 0000000000..ead5e4bee3 --- /dev/null +++ b/metagpt/ext/spo_api_backend/schemas.py @@ -0,0 +1,32 @@ +# metagpt/ext/spo_api_backend/schemas.py +from pydantic import BaseModel, Field +from typing import List, Dict, Optional +from enum import Enum + + +class RoundResult(BaseModel): + round: int + prompt: str + succeed: bool + tokens: Optional[int] = None + answers: Optional[List[Dict]] = None + + +class TaskStatus(str, Enum): + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +class OptimizationResponse(BaseModel): + task_id: str + status: str + results: List[RoundResult] = [] + last_successful_prompt: Optional[str] = None + last_successful_round: Optional[int] = None + total_rounds: int = 0 + successful_rounds: int = 0 + error_message: Optional[str] = None + start_time: Optional[float] = None + end_time: Optional[float] = None + elapsed_time: Optional[float] = None diff --git a/metagpt/ext/spo_api_backend/spo_api.py b/metagpt/ext/spo_api_backend/spo_api.py new file mode 100644 index 0000000000..49ed86d58b --- /dev/null +++ b/metagpt/ext/spo_api_backend/spo_api.py @@ -0,0 +1,361 @@ +import asyncio +import uuid +import time +import os +import sys +from pathlib import Path +from typing import Dict, List, Optional +from enum import Enum + +from fastapi 
import FastAPI, HTTPException, BackgroundTasks +from pydantic import BaseModel, Field +import yaml +from loguru import logger +import nest_asyncio +nest_asyncio.apply() +from fastapi.background import BackgroundTasks +import redis +#from .tasks import run_optimization_celery +import json +from metagpt.ext.spo_api_backend.tasks import run_optimization_celery +from metagpt.ext.spo_api_backend.schemas import OptimizationResponse, TaskStatus, RoundResult + +r = redis.Redis(decode_responses=True) + + + + +def save_task_to_redis(task: OptimizationResponse): + r.set(f"task:{task.task_id}", task.json()) + + +def load_task_from_redis(task_id: str) -> Optional[OptimizationResponse]: + raw = r.get(f"task:{task_id}") + if not raw: + return None + return OptimizationResponse.parse_raw(raw) + + + +# Startup script for solving path problems +def setup_metagpt_environment(): + """Set up MetaGPT environment""" + current_file = Path(__file__).absolute() + + # Try multiple possible MetaGPT root directory locations + possible_roots = [ + current_file.parent.parent.parent, # if in metagpt/ext/spo/ + current_file.parent.parent.parent.parent, # If deeper + Path.cwd(), # 当Previous Work Catalog + Path.cwd().parent, # parent directory + ] + + metagpt_root = None + for root in possible_roots: + if (root / "metagpt" / "__init__.py").exists(): + metagpt_root = root + break + + if metagpt_root is None: + # If not found, use the parent directory of the current directory + metagpt_root = current_file.parent.parent.parent + print(f"Warning: Could not auto-detect MetaGPT root, using: {metagpt_root}") + + # 添加到 Python 路径 + if str(metagpt_root) not in sys.path: + sys.path.insert(0, str(metagpt_root)) + + # 设置环境变量 + os.environ['METAGPT_ROOT'] = str(metagpt_root) + + return metagpt_root + +# 设置环境 +METAGPT_ROOT = setup_metagpt_environment() + +# 尝试导入 MetaGPT 模块 +try: + from metagpt.ext.spo.components.optimizer import PromptOptimizer + from metagpt.ext.spo.utils.llm_client import SPO_LLM + print("✓ 
Successfully imported MetaGPT modules") +except ImportError as e: + print(f"❌ Failed to import MetaGPT modules: {e}") + print(f"Current directory: {Path.cwd()}") + print(f"Script directory: {Path(__file__).parent}") + print(f"MetaGPT root: {METAGPT_ROOT}") + print(f"Python path: {sys.path[:3]}...") # 只显示前几个路径 + + # 提供详细的错误信息和解决方案 + print("\n🔧 Troubleshooting steps:") + print("1. Make sure you're running this script from the correct directory") + print("2. Check if MetaGPT is properly installed") + print("3. Try running: pip install -e . (from MetaGPT root directory)") + print("4. Or set PYTHONPATH manually: export PYTHONPATH=/path/to/MetaGPT:$PYTHONPATH") + + raise SystemExit(1) + + +# Pydantic Models +class ModelConfig(BaseModel): + model: str + temperature: float = Field(ge=0.0, le=1.0) + + +class OptimizationRequest(BaseModel): + optimization_model: ModelConfig + evaluation_model: ModelConfig + execution_model: ModelConfig + template_path: str + initial_round: int = Field(ge=1, le=100, default=1) + max_rounds: int = Field(ge=1, le=100, default=10) + task_name: Optional[str] = None + + +class RoundResult(BaseModel): + round: int + prompt: str + succeed: bool + tokens: Optional[int] = None + answers: Optional[List[Dict]] = None + + +class OptimizationResponse(BaseModel): + task_id: str + status: str # "running", "completed", "failed" + results: List[RoundResult] = [] + last_successful_prompt: Optional[str] = None + last_successful_round: Optional[int] = None + total_rounds: int = 0 + successful_rounds: int = 0 + error_message: Optional[str] = None + start_time: Optional[float] = None + end_time: Optional[float] = None + elapsed_time: Optional[float] = None + + +class TaskStatus(str, Enum): + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +# Global storage for tasks (in production, use Redis or database) +task_storage: Dict[str, OptimizationResponse] = {} + + +# FastAPI app +app = FastAPI( + title="SPO API - Self-Supervised Prompt 
Optimization", + description="API for running prompt optimization tasks concurrently", + version="1.0.0" +) + + +def load_yaml_template(template_path: Path) -> Dict: + """Load YAML template from file path""" + if template_path.exists(): + with open(template_path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) + raise FileNotFoundError(f"Template file not found: {template_path}") + + +async def run_optimization_task(task_id: str, request: OptimizationRequest): + """Run the optimization task asynchronously""" + try: + # Update task status to running + task = load_task_from_redis(task_id) + if task: + task.status = TaskStatus.RUNNING + task.start_time = time.time() + save_task_to_redis(task) + + logger.info(f"Starting optimization task {task_id}") + + # Validate template path + template_path = Path(request.template_path) + if not template_path.exists(): + raise FileNotFoundError(f"Template file not found: {template_path}") + + # Initialize LLM + SPO_LLM.initialize( + optimize_kwargs={ + "model": request.optimization_model.model, + "temperature": request.optimization_model.temperature + }, + evaluate_kwargs={ + "model": request.evaluation_model.model, + "temperature": request.evaluation_model.temperature + }, + execute_kwargs={ + "model": request.execution_model.model, + "temperature": request.execution_model.temperature + }, + ) + + # Extract template name from path + template_name = template_path.stem + + # Create workspace directory + workspace_path = f"workspace_{task_id}" + workspace_dir = Path(workspace_path) + workspace_dir.mkdir(exist_ok=True) + + # Create optimizer instance + optimizer = PromptOptimizer( + optimized_path=workspace_path, + initial_round=request.initial_round, + max_rounds=request.max_rounds, + template=template_path.name, + name=request.task_name or template_name, + ) + + # Copy template to optimizer's settings directory if needed + optimizer_template_path = optimizer.root_path / "settings" / template_path.name + 
optimizer_template_path.parent.mkdir(parents=True, exist_ok=True) + + if not optimizer_template_path.exists(): + import shutil + shutil.copy2(template_path, optimizer_template_path) + + # Run optimization + if hasattr(optimizer, 'aoptimize'): + await optimizer.aoptimize() # 使用异步版本 + else: + # 如果异步版本不存在,在后台线程中运行同步版本 + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, optimizer.optimize) + + # Load results + prompt_path = optimizer.root_path / "prompts" + result_data = optimizer.data_utils.load_results(prompt_path) + + # Process results + results = [] + last_successful_prompt = None + last_successful_round = None + successful_rounds = 0 + + for result in result_data: + round_result = RoundResult( + round=result["round"], + prompt=result["prompt"], + succeed=result["succeed"], + tokens=result.get("tokens"), + answers=result.get("answers", []) + ) + results.append(round_result) + + if result["succeed"]: + successful_rounds += 1 + last_successful_prompt = result["prompt"] + last_successful_round = result["round"] + + # Update task with results + end_time = time.time() + task = load_task_from_redis(task_id) + if task: + task.status = TaskStatus.COMPLETED + task.results = results + task.last_successful_prompt = last_successful_prompt + task.last_successful_round = last_successful_round + task.total_rounds = len(results) + task.successful_rounds = successful_rounds + task.end_time = end_time + task.elapsed_time = end_time - task.start_time + save_task_to_redis(task) + + logger.info(f"Optimization task {task_id} completed successfully") + + except Exception as e: + logger.error(f"Optimization task {task_id} failed: {str(e)}") + task = load_task_from_redis(task_id) + if task: + task.status = TaskStatus.FAILED + task.error_message = str(e) + task.end_time = time.time() + if task.start_time: + task.elapsed_time = task.end_time - task.start_time + save_task_to_redis(task) + + +@app.post("/optimize", response_model=OptimizationResponse) +async def 
start_optimization(request: OptimizationRequest): + task_id = str(uuid.uuid4()) + + task = OptimizationResponse( + task_id=task_id, + status=TaskStatus.RUNNING + ) + + save_task_to_redis(task) + + run_optimization_celery.delay(task_id, request.dict()) + + return task + + + +@app.get("/status/{task_id}", response_model=OptimizationResponse) +async def get_task_status(task_id: str): + """Get the status of an optimization task""" + + task = load_task_from_redis(task_id) + if not task: + raise HTTPException(status_code=404, detail="Task not found") + return task + + + +@app.get("/tasks", response_model=List[OptimizationResponse]) +async def list_all_tasks(): + """List all optimization tasks""" + keys = r.keys("task:*") + tasks = [OptimizationResponse.parse_raw(r.get(k)) for k in keys] + return tasks + + +@app.delete("/tasks/{task_id}") +async def delete_task(task_id: str): + if not r.exists(f"task:{task_id}"): + raise HTTPException(status_code=404, detail="Task not found") + r.delete(f"task:{task_id}") + return {"message": f"Task {task_id} deleted successfully"} + +@app.get("/health") +async def health_check(): + """Health check endpoint""" + return { + "status": "healthy", + "metagpt_root": str(METAGPT_ROOT), + "current_dir": str(Path.cwd()) + } + + +@app.get("/debug") +async def debug_info(): + """Debug information endpoint""" + return { + "current_directory": str(Path.cwd()), + "script_directory": str(Path(__file__).parent), + "metagpt_root": str(METAGPT_ROOT), + "python_path": sys.path[:5], # 只显示前5个路径 + "environment_vars": { + "METAGPT_ROOT": os.environ.get("METAGPT_ROOT"), + "PYTHONPATH": os.environ.get("PYTHONPATH", "Not set") + } + } + + +if __name__ == "__main__": + import uvicorn + + print(f"🚀 Starting SPO API server...") + print(f"📁 MetaGPT root: {METAGPT_ROOT}") + print(f"📍 Current directory: {Path.cwd()}") + print(f"🌐 API docs will be available at: http://localhost:8000/docs") + + # 创建日志目录 + log_dir = METAGPT_ROOT / "logs" + log_dir.mkdir(exist_ok=True) + + 
uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/metagpt/ext/spo_api_backend/tasks.py b/metagpt/ext/spo_api_backend/tasks.py new file mode 100644 index 0000000000..50aa383628 --- /dev/null +++ b/metagpt/ext/spo_api_backend/tasks.py @@ -0,0 +1,50 @@ +from .celery_app import celery_app +import asyncio +import logging +import redis +import json +from metagpt.ext.spo_api_backend.schemas import OptimizationResponse + +r = redis.Redis() + +def redis_key(task_id: str) -> str: + return f"task:{task_id}" + +def save_task_to_redis(task: OptimizationResponse): + r.set(redis_key(task.task_id), task.json()) + +def load_task_from_redis(task_id: str) -> OptimizationResponse | None: + raw = r.get(redis_key(task_id)) + if raw: + return OptimizationResponse.parse_raw(raw) + return None + + +@celery_app.task(bind=True) +def run_optimization_celery(self, task_id: str, request_dict: dict): + try: + from metagpt.ext.spo_api_backend.spo_api import run_optimization_task, OptimizationRequest + + request = OptimizationRequest(**request_dict) + asyncio.run(run_optimization_task(task_id, request)) + + return {"status": "success", "task_id": task_id} + + except Exception as e: + logging.exception(f"[Celery Task Failed] Task {task_id}: {e}") + + try: + task = load_task_from_redis(task_id) + if task: + task.status = "failed" + task.error_message = str(e) + save_task_to_redis(task) + except Exception: + logging.warning(f"Couldn't update task in Redis for {task_id}") + + return {"status": "failed", "task_id": task_id, "error": str(e)} + + +@celery_app.task +def ping(): + return "pong"