diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 2bb16e85..57b0328c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -42,7 +42,7 @@ jobs: run: | uv pip install huggingface_hub[cli] huggingface-cli download --repo-type dataset The-OpenROAD-Project/ORAssistant_RAG_Dataset --include source_list.json --local-dir data/ - export GOOGLE_API_KEY="dummy-unit-test-key" + cp .env.test .env make test - name: Build Docker images diff --git a/backend/.env.example b/backend/.env.example index e7067470..d5729375 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -50,10 +50,16 @@ FAST_MODE=false # Debug mode for development DEBUG=false +# Enable MCP (Model Context Protocol) tools +ENABLE_MCP=false + # MCP Server Configuration # Path to OpenROAD Flow Scripts directory ORFS_DIR={{PATH_TO_ORFS_DIR}} +# Disable GUI commands (set to true for headless environments) +DISABLE_GUI=true + # Repository commit hashes for documentation building OR_REPO_COMMIT=ffc5760f2df639cd184c40ceba253c7e02a006d5 ORFS_REPO_COMMIT={{ORFS_REPO_COMMIT}} diff --git a/backend/.env.test b/backend/.env.test new file mode 100644 index 00000000..ea5276cb --- /dev/null +++ b/backend/.env.test @@ -0,0 +1,32 @@ +# Test Environment Variables for CI +# This file contains dummy values for running unit tests + +# Google API (dummy key for tests) +GOOGLE_API_KEY=dummy-unit-test-key + +# Embedding Configuration +EMBEDDINGS_TYPE=HF +HF_EMBEDDINGS=thenlper/gte-large +HF_RERANKER=BAAI/bge-reranker-base +GOOGLE_EMBEDDINGS=text-embedding-004 + +# LLM Configuration +LLM_MODEL=gemini +LLM_TEMP=1 +GOOGLE_GEMINI=2.0_flash +OLLAMA_MODEL= + +# System Configuration +USE_CUDA=false +SEARCH_K=5 +CHUNK_SIZE=2000 +CHUNK_OVERLAP=200 + +# Optional settings +FAISS_DB_PATH=./.faissdb/faiss_index +TOKENIZERS_PARALLELISM=false +LOGLEVEL=INFO +FAST_MODE=false +DEBUG=false +ENABLE_MCP=false +DISABLE_GUI=true diff --git a/backend/mcp_server.py b/backend/mcp_server.py new file mode 100644 index 00000000..a8f1202a --- /dev/null +++ b/backend/mcp_server.py @@ -0,0 +1,12 @@ +import os +import logging + +from src.openroad_mcp.server.orfs.orfs_server import ORFSServer + +logging.basicConfig( + level=os.environ.get("LOGLEVEL", "INFO").upper(), + format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s\n", +) + +server = ORFSServer() +server.mcp.run(transport="http", host="127.0.0.1", port=3001) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index aaca8dfa..82733d11 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -52,3 +52,15 @@ dev = [ [tool.pytest.ini_options] pythonpath = ["."] +markers = [ + "unit: marks tests as unit tests (fast, isolated)", + "integration: marks tests as integration tests (slower, may require external resources)", +] +asyncio_mode = "auto" + +[tool.mypy] +exclude = [ + "^tests/", +] +python_version = "3.13" +disallow_untyped_defs = true diff --git a/backend/src/agents/retriever_graph.py b/backend/src/agents/retriever_graph.py index a7440355..cfd8d809 100644 --- a/backend/src/agents/retriever_graph.py +++ b/backend/src/agents/retriever_graph.py @@ -110,7 +110,7 @@ def fork_route(self, state: AgentState) -> str: if not self.enable_mcp: tmp = "rag_agent" else: - tmp = state["agent_type"][0] + tmp = "mcp_agent" return tmp def initialize(self) -> None: diff --git a/backend/src/agents/retriever_mcp.py b/backend/src/agents/retriever_mcp.py index 7059bbfc..d2e8b4fa 100644 --- a/backend/src/agents/retriever_mcp.py +++ b/backend/src/agents/retriever_mcp.py @@ -75,17 +75,16 @@ def mcp_tool_node(self, state: AgentState) -> dict[str, list[Any]]: logging.info(tool_call["args"]) try: observation = asyncio.run(tool.ainvoke(tool_call["args"])) + result.append(observation) except ToolException as e: + error_msg = f"Tool '{tool_call['name']}' failed: {str(e)}" logging.error(f"ToolException during {tool_call['name']}: {e}") - observation = None + result.append(error_msg) except Exception as e: + error_msg = f"Tool '{tool_call['name']}' encountered an error: {str(e)}" logging.error(f"Unexpected error during {tool_call['name']}: {e}") - observation = None + result.append(error_msg) - if observation: - result.append(observation) - else: - result.append("no return") logging.info("DONE") logging.info(result) return {"messages": result} diff --git a/backend/src/api/routers/graphs.py b/backend/src/api/routers/graphs.py index 4c443621..166d5c6b 100644 --- a/backend/src/api/routers/graphs.py +++ b/backend/src/api/routers/graphs.py @@ -278,7 +278,7 @@ async def get_agent_response(user_input: UserInput) -> ChatResponse: return ChatResponse(**response) -async def get_response_stream(user_input: UserInput): +async def get_response_stream(user_input: UserInput): # type: ignore[no-untyped-def] user_question = user_input.query inputs = { @@ -316,7 +316,7 @@ async def get_response_stream(user_input: UserInput): @router.post("/agent-retriever/stream", response_class=StreamingResponse) -async def get_agent_response_streaming(user_input: UserInput): +async def get_agent_response_streaming(user_input: UserInput): # type: ignore[no-untyped-def] return StreamingResponse( get_response_stream(user_input), media_type="text/event-stream" ) diff --git a/backend/src/openroad_mcp/client/client.py b/backend/src/openroad_mcp/client/client.py index a12e6b60..c07d34b9 100644 --- a/backend/src/openroad_mcp/client/client.py +++ b/backend/src/openroad_mcp/client/client.py @@ -1,13 +1,15 @@ import asyncio import logging +from typing import Any +from datetime import timedelta from langchain_mcp_adapters.client import MultiServerMCPClient # type: ignore MCP_SERVER_URL = "http://localhost:3001/mcp/" -_tools_cache = None +_tools_cache: Any = None -async def get_tools_async(): +async def get_tools_async() -> Any: """Get MCP tools asynchronously""" global _tools_cache if _tools_cache is None: @@ -17,6 +19,11 @@ async def get_tools_async(): "orfs_cmd": { "transport": "streamable_http", "url": MCP_SERVER_URL, + # TODO: remove this once tools are async! + # HTTP request timeout - increase for long make commands + "timeout": timedelta(hours=2), + # SSE read timeout - how long to wait for events + "sse_read_timeout": timedelta(hours=2), }, } ) @@ -28,7 +35,7 @@ async def get_tools_async(): return _tools_cache -def get_tools(): +def get_tools() -> Any: """Get MCP tools synchronously""" try: return asyncio.run(get_tools_async()) diff --git a/backend/src/openroad_mcp/server/orfs/orfs_base.py b/backend/src/openroad_mcp/server/orfs/orfs_base.py new file mode 100644 index 00000000..fece37af --- /dev/null +++ b/backend/src/openroad_mcp/server/orfs/orfs_base.py @@ -0,0 +1,195 @@ +import os +import subprocess +import logging +import shlex +from src.openroad_mcp.server.orfs.orfs_tools import ORFS + + +def _should_skip_gui() -> bool: + """Check if GUI commands should be skipped based on environment variable.""" + return os.getenv("DISABLE_GUI", "false").lower() in ("true", "1", "yes") + + +class ORFSBase(ORFS): + def _get_platforms_impl(self) -> str: + """Internal implementation of get_platforms""" + # TODO: scrape platforms instead of serving only default sky130 + assert ORFS.server is not None + ORFS.server.platform = "sky130hd" + return ORFS.server.platform + + def _get_designs_impl(self) -> str: + """Internal implementation of get_designs""" + # TODO: scrape designs instead of default riscv + assert ORFS.server is not None + ORFS.server.design = "riscv32i" + return ORFS.server.design + + def _check_configuration(self) -> None: + assert ORFS.server is not None + if not ORFS.server.platform: + ORFS.server._get_platforms_impl() + logging.info(ORFS.server.platform) + + if not ORFS.server.design: + ORFS.server._get_designs_impl() + logging.info(ORFS.server.design) + + def _command(self, cmd: str) -> None: + assert ORFS.server is not None + working = os.getcwd() + os.chdir(ORFS.server.flow_dir) + + make = f"make DESIGN_CONFIG={ORFS.server.flow_dir}/designs/{ORFS.server.platform}/{ORFS.server.design}/config.mk" + logging.info(cmd) + build_command = f"{make} {cmd}" + ORFS.server._run_command(build_command) + + os.chdir(working) + + def _run_command(self, cmd: str) -> None: + assert ORFS.server is not None + logging.info("start command") + + process = subprocess.Popen( + shlex.split(cmd), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + bufsize=1, # Line-buffered + universal_newlines=True, # Text mode + env=ORFS.server.env, + ) + + if process.stdout: + for line in process.stdout: + logging.info(line.rstrip()) + + process.wait() + if process.returncode != 0: + logging.error(f"Command exited with return code {process.returncode}") + raise subprocess.CalledProcessError(process.returncode, cmd) + + ### mcp tool section ### + + @staticmethod + @ORFS.mcp.tool + def get_platforms() -> str: + """call get platforms to display possible platforms to run through flow""" + assert ORFS.server is not None + return ORFS.server._get_platforms_impl() + + @staticmethod + @ORFS.mcp.tool + def get_designs() -> str: + """call get designs to display possible designs to run through flow""" + assert ORFS.server is not None + return ORFS.server._get_designs_impl() + + @staticmethod + @ORFS.mcp.tool + def make(cmd: str) -> str: + """Execute a makefile target for OpenROAD-flow-scripts. + + Common commands: + - "clean" - Remove all build artifacts and start fresh + - "synth" - Run synthesis + - "place" - Run placement + - "route" - Run routing + - "final" - Generate final reports + + Use this for any makefile target not covered by step/jump commands. + """ + assert ORFS.server is not None + ORFS.server._check_configuration() + ORFS.server._command(cmd) + + return f"finished {cmd}" + + @staticmethod + @ORFS.mcp.tool + def get_stage_names() -> str: + """get stage names for possible states this mcp server can be in the chip design pipeline""" + assert ORFS.server is not None + stage_names = [_.info() for _ in ORFS.server.stages.values()] + logging.info(stage_names) # in server process + # for chatbot output + result = "" + for _ in stage_names: + result += f"{_}\n" + return result + + @staticmethod + @ORFS.mcp.tool + def jump(stage: str) -> str: + """Jump directly to a specific stage in the chip design pipeline. + + Valid stage names (MUST use exact names): + - "synth" - Synthesis + - "floorplan" - Floorplan + - "place" - Placement + - "cts" - Clock Tree Synthesis + - "route" - Routing + - "final" - Final Report + + Use get_stage_names() to see all available stages. + """ + assert ORFS.server is not None + ORFS.server._check_configuration() + + stage_names = [_.info() for _ in ORFS.server.stages.values()] + logging.info(stage_names) + if stage in stage_names: + logging.info(stage) + ORFS.server.cur_stage = ORFS.server.stage_index[stage] + + ORFS.server._command(stage) + + # Open GUI if not disabled + if not _should_skip_gui(): + try: + ORFS.server._command(f"gui_{stage}") + except subprocess.CalledProcessError as e: + logging.warning(f"GUI command failed: {e}") + else: + logging.info("Skipping GUI command (DISABLE_GUI=true)") + + return f"finished {stage}" + else: + logging.info("jump unsuccessful..") + return f"aborted {stage}" + + @staticmethod + @ORFS.mcp.tool + def step() -> str: + """Progress to the next stage in the chip design pipeline (synthesis -> floorplan -> placement -> CTS -> routing -> final report)""" + assert ORFS.server is not None + + def make_keyword() -> str: + assert ORFS.server is not None + logging.info(ORFS.server.cur_stage) + if ORFS.server.cur_stage <= len(ORFS.server.stages) - 2: + ORFS.server.cur_stage += 1 + else: + logging.info("end of pipeline..") + return ORFS.server.stages[ORFS.server.cur_stage].info() + + ORFS.server._check_configuration() + + command = make_keyword() + ORFS.server._command(command) + + # Open GUI if not disabled + if not _should_skip_gui(): + try: + ORFS.server._command(f"gui_{command}") + except subprocess.CalledProcessError as e: + logging.warning(f"GUI command failed: {e}") + else: + logging.info("Skipping GUI command (DISABLE_GUI=true)") + + return f"finished {command}" + + # TODO: scrape all makefile keywords and make into mcp tool + @staticmethod + def get_all_keywords() -> None: + pass diff --git a/backend/src/openroad_mcp/server/orfs/orfs_make.py b/backend/src/openroad_mcp/server/orfs/orfs_make.py new file mode 100644 index 00000000..5048ebdf --- /dev/null +++ b/backend/src/openroad_mcp/server/orfs/orfs_make.py @@ -0,0 +1,157 @@ +import logging +from src.openroad_mcp.server.orfs.orfs_tools import ORFS +from langchain.prompts import ChatPromptTemplate +from langchain_core.output_parsers import JsonOutputParser +from src.prompts.prompt_templates import env_prompt_template + + +class ORFSMake(ORFS): + """Handles dynamic Makefile generation and environment configuration for ORFS. + + This class manages the creation of dynamic Makefile configurations and environment + variables for OpenROAD-flow-scripts builds. It provides MCP tools for generating + config.mk files and extracting environment variables from documentation using + RAG (Retrieval-Augmented Generation). + """ + + def _get_default_makefile(self) -> None: + """Set the makefile pointer to the default config.mk location.""" + assert ORFS.server is not None + ORFS.server.makefile_pointer = f"{ORFS.server.flow_dir}/designs/{ORFS.server.platform}/{ORFS.server.design}/config.mk" + + def _get_makefile(self) -> None: + """Retrieve the current makefile pointer path.""" + assert ORFS.server is not None + if ORFS.server.makefile_pointer: + return ORFS.server.makefile_pointer + + def _get_default_env(self) -> None: + """Initialize default environment variables for ORFS build configuration.""" + # TODO: categorize into ORFSEnv TypedDict + assert ORFS.server is not None + ORFS.server.orfs_env.update( + { + "PLATFORM": f"{ORFS.server.platform}", + "DESIGN_NAME": f"{ORFS.server.design}", + "DESIGN_NICKNAME": f"{ORFS.server.design}", + "VERILOG_FILES": "$(sort $(wildcard ./designs/src/$(DESIGN_NICKNAME)/*.v))", + "SDC_FILE": "./designs/$(PLATFORM)/$(DESIGN_NICKNAME)/constraint.sdc", + "CORE_UTILIZATION": "50", + "PLACE_DENSITY": "50", + } + ) + + @staticmethod + @ORFS.mcp.tool + def create_dynamic_makefile(cmd: str) -> str: + """Create a dynamic Makefile configuration for the current ORFS design. + + Generates a dynamic_config.mk file containing exported environment variables + for the OpenROAD-flow-scripts build system. This enables custom build + configurations without modifying the default config.mk files. + + The tool ensures a design and platform are selected (prompting initialization + if needed), then creates a Makefile with all current environment variables + exported in the format: `export VAR_NAME = value` + + Args: + cmd: User command or query string (currently unused - may be used for + future query-based configuration). + + Returns: + str: A formatted string listing all environment variables and their values, + one per line in the format "VAR_NAME: value". Returns "no env vars" + if no environment variables are configured. + + Side Effects: + - Sets ORFS.server.dynamic_makefile to True + - Creates/overwrites dynamic_config.mk in the design directory + - Updates ORFS.server.makefile_pointer to the dynamic config path + - May trigger design/platform initialization if not already set + - Logs warning if no design/platform is selected + + Example: + >>> create_dynamic_makefile("setup build config") + "PLATFORM: asap7\\nDESIGN_NAME: gcd\\nCORE_UTILIZATION: 50\\n..." + + Note: + File is written to: {flow_dir}/designs/{platform}/{design}/dynamic_config.mk + """ + assert ORFS.server is not None + if not (ORFS.server.design and ORFS.server.platform): + logging.warning("no custom design/platform selected!") + ORFS.server._get_designs_impl() + ORFS.server._get_platforms_impl() + ORFS.server._get_default_env() + else: + pass + + ORFS.server.dynamic_makefile = True + ORFS.server.makefile_pointer = f"{ORFS.server.flow_dir}/designs/{ORFS.server.platform}/{ORFS.server.design}/dynamic_config.mk" + with open(f"{ORFS.server.makefile_pointer}", "w") as f: + for key in ORFS.server.orfs_env.keys(): + f.write(f"export {key} = {ORFS.server.orfs_env[key]}\n") + result = "" + for key in ORFS.server.orfs_env.keys(): + result += f"{key}: {ORFS.server.orfs_env[key]}\n" + if result: + return result + else: + return "no env vars" + + @staticmethod + @ORFS.mcp.tool + def get_env_vars(cmd: str) -> str: + """Extract environment variables from ORFS documentation using RAG. + + Uses Retrieval-Augmented Generation (RAG) to search ORFS documentation + for relevant environment variables based on a user query, then uses an + LLM to parse and extract them in Makefile export format. The extracted + variables are automatically added to the current environment configuration. + + This tool is useful for discovering and applying environment variables + from documentation without manual lookup. + + Args: + cmd: Natural language query describing what environment variables + are needed (e.g., "variables for clock period and frequency"). + + Returns: + str: Always returns "done env" to indicate completion. Check logs + for detailed information about retrieved variables. + + Side Effects: + - Initializes design and platform if not already set + - Queries RAG system for relevant documentation + - Invokes LLM chain to parse environment variables from context + - Updates ORFS.server.orfs_env with extracted variables + - Logs retrieved context, parsed variables, and variable types + + Example: + >>> get_env_vars("get clock period and frequency variables") + "done env" + # ORFS.server.orfs_env now updated with CLOCK_PERIOD, CLOCK_FREQ, etc. + + Note: + The LLM is prompted to output only valid Makefile export statements + in the format: `export {env_name} = {env_value}` + """ + assert ORFS.server is not None + assert ORFS.llm is not None + ORFS.server._get_designs_impl() + ORFS.server._get_platforms_impl() + ORFS.server._get_default_env() + result = ORFS.server.retrieve_general(cmd) + logging.info(ORFS.llm) + env_chain = ( + ChatPromptTemplate.from_template(env_prompt_template) + | ORFS.llm + | JsonOutputParser() + ) + string = "Only output environment variables in the following format:\n export {env_name} = {env_value}\n" + ans = env_chain.invoke({"context": result[0], "question": string + cmd}) + logging.info(result[0]) + logging.info(ans) + ORFS.server.orfs_env.update(ans) + logging.info(type(ans)) + return "done env" diff --git a/backend/src/openroad_mcp/server/orfs/orfs_rag.py b/backend/src/openroad_mcp/server/orfs/orfs_rag.py new file mode 100644 index 00000000..98b3a206 --- /dev/null +++ b/backend/src/openroad_mcp/server/orfs/orfs_rag.py @@ -0,0 +1,320 @@ +import os +import logging +from src.openroad_mcp.server.orfs.orfs_tools import ORFS + +from typing import Tuple +from langchain_ollama import ChatOllama +from langchain_google_vertexai import ChatVertexAI +from src.chains.hybrid_retriever_chain import HybridRetrieverChain +from src.tools.format_docs import format_docs + + +class ORFSRag(ORFS): + required_env_vars = [ + "USE_CUDA", + "LLM_TEMP", + "HF_EMBEDDINGS", + "HF_RERANKER", + "LLM_MODEL", + ] + hf_reranker: str = str(os.getenv("HF_RERANKER")) + reranking_model_name = hf_reranker + embeddings_type: str = str(os.getenv("EMBEDDINGS_TYPE")) + use_cuda: bool = False + if str(os.getenv("USE_CUDA")).lower() in ("true"): + use_cuda = True + + if embeddings_type == "HF": + embeddings_model_name = str(os.getenv("HF_EMBEDDINGS")) + elif embeddings_type == "GOOGLE_GENAI" or embeddings_type == "GOOGLE_VERTEXAI": + embeddings_model_name = str(os.getenv("GOOGLE_EMBEDDINGS")) + else: + raise ValueError( + "EMBEDDINGS_TYPE environment variable must be set to 'HF', 'GOOGLE_GENAI', or 'GOOGLE_VERTEXAI'." + ) + + embeddings_config = {"type": embeddings_type, "name": embeddings_model_name} + fast_mode: bool = False + markdown_docs_map = { + "general": [ + "./data/markdown/OR_docs", + "./data/markdown/ORFS_docs", + "./data/markdown/gh_discussions", + "./data/markdown/manpages/man1", + "./data/markdown/manpages/man2", + "./data/markdown/OpenSTA_docs", + ], + "install": [ + "./data/markdown/ORFS_docs/installation", + "./data/markdown/OR_docs/installation", + "./data/markdown/gh_discussions/Build", + "./data/markdown/gh_discussions/Installation", + "./data/markdown/OpenSTA_docs", + ], + "commands": [ + "./data/markdown/OR_docs/tools", + "./data/markdown/ORFS_docs/general", + "./data/markdown/gh_discussions/Query", + "./data/markdown/gh_discussions/Runtime", + "./data/markdown/gh_discussions/Documentation", + "./data/markdown/manpages/man1", + "./data/markdown/manpages/man2", + "./data/markdown/OpenSTA_docs", + ], + "errinfo": [ + "./data/markdown/manpages/man3", + "./data/markdown/gh_discussions/Bug", + ], + } + search_k = int(os.getenv("SEARCH_K", 10)) + chunk_size = int(os.getenv("CHUNK_SIZE", 4000)) + # llm: ChatGoogleGenerativeAI | ChatVertexAI | ChatOllama + + llm_temp_str = os.getenv("LLM_TEMP") + if llm_temp_str is not None: + llm_temp = float(llm_temp_str) + + if os.getenv("LLM_MODEL") == "ollama": + model_name = str(os.getenv("OLLAMA_MODEL")) + ORFS.llm = ChatOllama(model=model_name, temperature=llm_temp) + logging.info(ORFS.llm) + elif os.getenv("LLM_MODEL") == "gemini": + gemini_model = os.getenv("GOOGLE_GEMINI") + if gemini_model in {"1_pro", "1.5_flash", "1.5_pro"}: + raise ValueError( + f"The selected Gemini model '{gemini_model}' (version 1.0–1.5) is disabled. " + "Please upgrade to version 2.0 or higher (e.g., 2.0_flash, 2.5_flash, 2.5_pro)." + ) + elif gemini_model == "2.0_flash": + ORFS.llm = ChatVertexAI(model_name="gemini-2.0-flash", temperature=llm_temp) + elif gemini_model == "2.5_flash": + ORFS.llm = ChatVertexAI(model_name="gemini-2.5-flash", temperature=llm_temp) + elif gemini_model == "2.5_pro": + ORFS.llm = ChatVertexAI(model_name="gemini-2.5-pro", temperature=llm_temp) + else: + raise ValueError( + f"GOOGLE_GEMINI environment variable '{gemini_model}' not set to a valid value. " + "Valid options: 2.0_flash, 2.5_flash, 2.5_pro" + ) + else: + raise ValueError("LLM_MODEL environment variable not set to a valid value.") + ### + # TODO: remove fast_mode or keep? + general_retriever_chain = HybridRetrieverChain( + embeddings_config=embeddings_config, + reranking_model_name=reranking_model_name, + use_cuda=use_cuda, + # html_docs_path=[] if fast_mode else ["./data/html/or_website/"], + # markdown_docs_path=fastmode_docs_map["general"] + # if fast_mode + # else markdown_docs_map["general"], + # other_docs_path=[] if fast_mode else ["./data/pdf"], + html_docs_path=["./data/html/or_website/"], + markdown_docs_path=markdown_docs_map["general"], + other_docs_path=["./data/pdf"], + weights=[0.6, 0.2, 0.2], + contextual_rerank=True, + search_k=search_k, + chunk_size=chunk_size, + ) + general_retriever_chain.create_hybrid_retriever() + ORFS.general_retriever = general_retriever_chain.retriever + + install_retriever_chain = HybridRetrieverChain( + embeddings_config=embeddings_config, + reranking_model_name=reranking_model_name, + use_cuda=use_cuda, + # markdown_docs_path=fastmode_docs_map["install"] + # if fast_mode + # else markdown_docs_map["install"], + markdown_docs_path=markdown_docs_map["install"], + weights=[0.6, 0.2, 0.2], + contextual_rerank=True, + search_k=search_k, + chunk_size=chunk_size, + ) + install_retriever_chain.create_hybrid_retriever() + ORFS.install_retriever = install_retriever_chain.retriever + + commands_retriever_chain = HybridRetrieverChain( + embeddings_config=embeddings_config, + reranking_model_name=reranking_model_name, + use_cuda=use_cuda, + # markdown_docs_path=fastmode_docs_map["commands"] + # if fast_mode + # else markdown_docs_map["commands"], + # other_docs_path=[] if fast_mode else ["./data/pdf"], + markdown_docs_path=markdown_docs_map["commands"], + other_docs_path=["./data/pdf"], + weights=[0.6, 0.2, 0.2], + contextual_rerank=True, + search_k=search_k, + chunk_size=chunk_size, + ) + commands_retriever_chain.create_hybrid_retriever() + ORFS.commands_retriever = commands_retriever_chain.retriever + + yosys_rtdocs_retriever_chain = HybridRetrieverChain( + embeddings_config=embeddings_config, + reranking_model_name=reranking_model_name, + use_cuda=use_cuda, + # html_docs_path=fastmode_docs_map["yosys"] + # if fast_mode + # else ["./data/html/yosys_docs"], + html_docs_path=["./data/html/yosys_docs"], + weights=[0.6, 0.2, 0.2], + contextual_rerank=True, + search_k=search_k, + chunk_size=chunk_size, + ) + yosys_rtdocs_retriever_chain.create_hybrid_retriever() + ORFS.yosys_rtdocs_retriever = yosys_rtdocs_retriever_chain.retriever + + klayout_retriever_chain = HybridRetrieverChain( + embeddings_config=embeddings_config, + reranking_model_name=reranking_model_name, + use_cuda=use_cuda, + # html_docs_path=fastmode_docs_map["klayout"] + # if fast_mode + # else ["./data/html/klayout_docs"], + html_docs_path=["./data/html/klayout_docs"], + weights=[0.6, 0.2, 0.2], + contextual_rerank=True, + search_k=search_k, + chunk_size=chunk_size, + ) + klayout_retriever_chain.create_hybrid_retriever() + ORFS.klayout_retriever = klayout_retriever_chain.retriever + + errinfo_retriever_chain = HybridRetrieverChain( + embeddings_config=embeddings_config, + reranking_model_name=reranking_model_name, + use_cuda=use_cuda, + # markdown_docs_path=fastmode_docs_map["errinfo"] + # if fast_mode + # else markdown_docs_map["errinfo"], + markdown_docs_path=markdown_docs_map["errinfo"], + weights=[0.6, 0.2, 0.2], + contextual_rerank=True, + search_k=search_k, + chunk_size=chunk_size, + ) + errinfo_retriever_chain.create_hybrid_retriever() + ORFS.errinfo_retriever = errinfo_retriever_chain.retriever + + # TODO: make callable as mcp tool? + @staticmethod + @ORFS.mcp.tool + def retrieve_general(query: str) -> Tuple[str, list[str], list[str], list[str]]: + """ + Retrieve comprehensive and detailed information pertaining to the OpenROAD project, OpenROAD-Flow-Scripts and OpenSTA.\ + This includes, but is not limited to, general information, specific functionalities, usage guidelines,\ + troubleshooting steps, and best practices. The tool is designed to assist users by providing clear, accurate,\ + and relevant information that enhances their understanding and efficient use of OpenROAD and OpenROAD-Flow-Scripts.\ + """ + if ORFS.general_retriever is None: + raise ValueError("General Retriever not initialized") + else: + docs = ORFS.general_retriever.invoke(input=query) + return format_docs(docs) + + @staticmethod + @ORFS.mcp.tool + def retrieve_cmds(query: str) -> Tuple[str, list[str], list[str], list[str]]: + """ + Retrieve information on the commands available in OpenROAD, OpenROAD-Flow-Scripts and OpenSTA.\ + This includes usage guidelines, command syntax, examples, and best practices about commands that cover various \ + aspects of electronic design automation, such as synthesis, placement, routing, analysis, and \ + optimization within the OpenROAD environment. + + OR and ORFS Commands: + Antenna Rule Checker (ANT), Clock Tree Synthesis (CTS), Design For Testing (DFT), Detailed Placement (DPL), \ + Detailed Routing (DRT), Metal Fill (FIN), Floorplanning, Global Placement (GPL), Global Routing (GRT), Graphical User Interface (GUI), \ + Initialize Floorplan (IFP), Macro Placement (MPL), Hierarchical Macro Placement (MPL2), OpenDB (ODB), Chip-level Connections (PAD),\ + Partition Manager (PAR), Power Distribution Network (PDN), Pin Placement (PPL), IR Drop Analysis (PSM), Parasitics Extraction (RSX),\ + Restructure (RMP), Gate Resizer (RSZ), Rectilinear Steiner Tree (STT), TapCell (TAP), Read Unified Power Format (UPF), Timing Optimization\ + + OpenSTA is an open-source gate-level static timing verifier.\ + It can verify the timing of deisgns in the form of Verilog netlists.\ + Timing Analysis: Perform static timing analysis using standard file formats (Verilog, Liberty, SDC, SDF, SPEF). \ + Multiple Process Corners: Conduct analysis across different process variations. \ + Power Analysis: Evaluate power consumption in designs. \ + TCL Interpreter: Use TCL scripts for command automation and customization. \ + """ + if ORFS.commands_retriever is None: + raise ValueError("Commands Retriever not initialized") + else: + docs = ORFS.commands_retriever.invoke(input=query) + return format_docs(docs) + + @staticmethod + @ORFS.mcp.tool + def retrieve_install(query: str) -> Tuple[str, list[str], list[str], list[str]]: + """ + Retrieve comprehensive and detailed information pertaining to the installaion of OpenROAD, OpenROAD-Flow-Scripts and OpenSTA.\ + This includes, but is not limited to, various dependencies, system requirements, installation methods such as,\ + - Building from source\ + - Using Docker\ + - Using pre-built binaries\ + """ + if ORFS.install_retriever is None: + raise ValueError("Install Retriever not initialized") + else: + docs = ORFS.install_retriever.invoke(input=query) + return format_docs(docs) + + @staticmethod + @ORFS.mcp.tool + def retrieve_errinfo(query: str) -> Tuple[str, list[str], list[str], list[str]]: + """ + Retrieve descriptions and details regarding the various warning/error messages encountered while using the OpenROAD.\ + An error code usually is identified by the tool, followed by a number.\ + Examples: ANT-0001, CTS-0014 etc.\ + """ + + if ORFS.errinfo_retriever is None: + raise ValueError("Error Info Retriever not initialized") + else: + docs = ORFS.errinfo_retriever.invoke(input=query) + return format_docs(docs) + + @staticmethod + @ORFS.mcp.tool + def retrieve_yosys_rtdocs( + query: str, + ) -> Tuple[str, list[str], list[str], list[str]]: + """ + Retrieve detailed information regarding the Yosys application.\ + This tool provides information pertaining to the installation, usage, and troubleshooting of Yosys.\ + + Yosys is a framework for Verilog RTL synthesis.\ + It currently has extensive Verilog-2005 support and provides a basic set of synthesis algorithms for various application domains.\ + Setup: Configure Yosys for synthesis tasks. + Usage: Execute synthesis commands and scripts. + Troubleshooting: Resolve common issues in synthesis flows. + """ + + if ORFS.yosys_rtdocs_retriever is None: + raise ValueError("Yosys RTDocs Retriever not initialized") + else: + docs = ORFS.yosys_rtdocs_retriever.invoke(input=query) + return format_docs(docs) + + @staticmethod + @ORFS.mcp.tool + def retrieve_klayout_docs( + query: str, + ) -> Tuple[str, list[str], list[str], list[str]]: + """ + Retrieve detailed information regarding the KLayout application.\ + This tool provides information pertaining to the installation, usage, and troubleshooting of KLayout.\ + + KLayout is a powerful open-source layout viewer and editor designed for integrated circuit (IC) design.\ + It supports various file formats, including GDSII, OASIS, and DXF + """ + + if ORFS.klayout_retriever is None: + raise ValueError("KLayout Retriever not initialized") + else: + docs = ORFS.klayout_retriever.invoke(input=query) + return format_docs(docs) diff --git a/backend/src/openroad_mcp/server/orfs/orfs_server.py b/backend/src/openroad_mcp/server/orfs/orfs_server.py index e006bd99..e7f226eb 100644 --- a/backend/src/openroad_mcp/server/orfs/orfs_server.py +++ b/backend/src/openroad_mcp/server/orfs/orfs_server.py @@ -1,183 +1,66 @@ import os -import subprocess import logging +from typing import TypedDict, Any from dotenv import load_dotenv -from fastmcp import FastMCP -from typing import Optional -from .orfs_tools import ORFSTools +from .pipeline import Synthesis, Floorplan, Placement, CTS, Routing, FinalReport +from src.openroad_mcp.server.orfs.orfs_tools import ORFS +from src.openroad_mcp.server.orfs.orfs_make import ORFSMake +from src.openroad_mcp.server.orfs.orfs_base import ORFSBase +from src.openroad_mcp.server.orfs.orfs_rag import ORFSRag -load_dotenv() -env = os.environ -orfs_dir: Optional[str] = os.getenv("ORFS_DIR") -if orfs_dir is None: - raise ValueError("ORFS_DIR environment variable is not set") -flow_dir = os.path.join(orfs_dir, "flow") -mcp = FastMCP("ORFS") +logging.basicConfig( + level=os.environ.get("LOGLEVEL", "INFO").upper(), + format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s\n", +) -class ORFS(ORFSTools): - @staticmethod - def _get_platforms_impl() -> str: - """Internal implementation of get_platforms""" - # TODO: scrape platforms instead of serving only default sky130 - if False: - pass - else: - ORFS.platform = "sky130hd" - return ORFS.platform +class ORFSEnv(TypedDict): + general: list[str | None] - @staticmethod - def _get_designs_impl() -> str: - """Internal implementation of get_designs""" - # TODO: scrape designs instead of default riscv - if False: - pass - else: - ORFS.design = "riscv32i" - return ORFS.design + synthesis: list[str | None] + floorplan: list[str | None] + placement: list[str | None] + cts: list[str | None] + routing: list[str | None] - @mcp.tool - @staticmethod - def get_platforms() -> str: - """call get platforms to display possible platforms to run through flow""" - return ORFS._get_platforms_impl() - @mcp.tool - @staticmethod - def get_designs() -> str: - """call get designs to display possible designs to run through flow""" - return ORFS._get_designs_impl() +class ORFSServer(ORFSBase, ORFSMake, ORFSRag): + def __init__(self) -> None: + ORFS.server = self - @mcp.tool - @staticmethod - def make(cmd: str) -> str: - """call make command if query contains make keyword and a single argument""" - working = os.getcwd() - os.chdir(flow_dir) + self.design: str | None = None + self.platform: str | None = None + self.command: str | None = None + self.makefile_pointer: str | None = None - if not ORFS.platform: - platform = ORFS._get_platforms_impl() - logging.info(platform) + self.orfs_env: dict[str, str] = {} - if not ORFS.design: - design = ORFS._get_designs_impl() - logging.info(design) + self.design_list: list[str] = [] + self.stages: dict[int, Any] = { + 0: Synthesis(), + 1: Floorplan(), + 2: Placement(), + 3: CTS(), + 4: Routing(), + 5: FinalReport(), + } + self.stage_index: dict[str, int] = {v.info(): k for k, v in self.stages.items()} - make = f"make DESIGN_CONFIG={flow_dir}/designs/{ORFS.platform}/{ORFS.design}/config.mk" - build_command = make + " " + cmd - ORFS.run_command(build_command) + self.cur_stage: int = -1 - os.chdir(working) - return f"finished {cmd}" + self._setup_env() + logging.warning("instantiated...") - @mcp.tool - @staticmethod - def get_stage_names() -> str: - """get stage names for possible states this mcp server can be in the chip design pipeline""" - stage_names = [_.info() for _ in ORFS.stages.values()] - logging.info(stage_names) # in server process - # for chatbot output - result = "" - for _ in stage_names: - result += f"{_}\n" - return result - - @mcp.tool - @staticmethod - def jump(stage: str) -> str: - """call jump command if contains jump keyword and stage argument""" - working = os.getcwd() - os.chdir(flow_dir) - - if not ORFS.platform: - platform = ORFS._get_platforms_impl() - logging.info(platform) - - if not ORFS.design: - design = ORFS._get_designs_impl() - logging.info(design) - - make = f"make DESIGN_CONFIG={flow_dir}/designs/{ORFS.platform}/{ORFS.design}/config.mk" - stage_names = [_.info() for _ in ORFS.stages.values()] - logging.info(stage_names) - if stage in stage_names: - logging.info(stage) - build_command = make + " " + stage - ORFS.cur_stage = ORFS.stage_index[stage] - ORFS.run_command(build_command) - - build_gui_command = make + " gui_" + stage - ORFS.run_command(build_gui_command) - - os.chdir(working) - return f"finished {stage}" - else: - logging.info("jump unsuccessful...") - return f"aborted {stage}" - - @mcp.tool - @staticmethod - def step() -> str: - """call step command if contains step keyword to progress through pipeline""" - - def make_keyword(): - logging.info(ORFS.cur_stage) - if ORFS.cur_stage <= len(ORFS.stages) - 2: - ORFS.cur_stage += 1 - else: - logging.info("end of pipeline...") - return ORFS.stages[ORFS.cur_stage].info() - - working = os.getcwd() - os.chdir(flow_dir) - - if not ORFS.platform: - platform = ORFS._get_platforms_impl() - logging.info(platform) - - if not ORFS.design: - design = ORFS._get_designs_impl() - logging.info(design) - - make = f"make DESIGN_CONFIG={flow_dir}/designs/{ORFS.platform}/{ORFS.design}/config.mk" - command = make_keyword() - logging.info(command) - build_command = make + " " + command - ORFS.run_command(build_command) - - build_gui_command = make + " gui_" + command - ORFS.run_command(build_gui_command) - - os.chdir(working) - return f"finished {command}" - - @staticmethod - def run_command(cmd: str) -> None: - logging.info("start command") - - process = subprocess.Popen( - cmd.split(), - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - bufsize=1, # Line-buffered - universal_newlines=True, # Text mode - env=env, - ) - - if process.stdout: - for line in process.stdout: - logging.info(line.rstrip()) - - process.wait() - if process.returncode != 0: - logging.error(f"Command exited with return code {process.returncode}") - raise subprocess.CalledProcessError(process.returncode, cmd) - - # TODO: scrape all makefile keywords and make into mcp tool - def get_all_keywords(self) -> None: - pass + def _setup_env(self) -> None: + load_dotenv() + self.env = os.environ + self.orfs_dir: str | None = os.getenv("ORFS_DIR") + if self.orfs_dir is None: + raise ValueError("ORFS_DIR environment variable is not set") + self.flow_dir = os.path.join(self.orfs_dir, "flow") if __name__ == "__main__": - mcp.run(transport="http", host="127.0.0.1", port=3001) + server = ORFSServer() + server.mcp.run(transport="http", host="127.0.0.1", port=3001) diff --git a/backend/src/openroad_mcp/server/orfs/orfs_tools.py b/backend/src/openroad_mcp/server/orfs/orfs_tools.py index acdbc9b1..b2c7707a 100644 --- a/backend/src/openroad_mcp/server/orfs/orfs_tools.py +++ b/backend/src/openroad_mcp/server/orfs/orfs_tools.py @@ -1,23 +1,15 @@ -from .pipeline import Synthesis, Floorplan, Placement, CTS, Routing, FinalReport -from typing import Any +from typing import Any, Optional +from fastmcp import FastMCP -# Global variables and functions for ORFS class -class ORFSTools: - design: str | None = None - platform: str | None = None - command: str | None = None +class ORFS: + mcp = FastMCP("ORFS") + server: Optional[Any] = None - design_list: list[str] = [] - - stages: dict[int, Any] = { - 0: Synthesis(), - 1: Floorplan(), - 2: Placement(), - 3: CTS(), - 4: Routing(), - 5: FinalReport(), - } - stage_index: dict[str, int] = {v.info(): k for k, v in stages.items()} - - cur_stage: int = -1 + llm: Optional[Any] = None + general_retriever: Optional[Any] = None + install_retriever: Optional[Any] = None + commands_retriever: Optional[Any] = None + yosys_rtdocs_retriever: Optional[Any] = None + klayout_retriever: Optional[Any] = None + errinfo_retriever: Optional[Any] = None diff --git a/backend/src/openroad_mcp/server/orfs/pipeline/cts.py b/backend/src/openroad_mcp/server/orfs/pipeline/cts.py index d62dabaf..4bdd97ca 100644 --- a/backend/src/openroad_mcp/server/orfs/pipeline/cts.py +++ b/backend/src/openroad_mcp/server/orfs/pipeline/cts.py @@ -1,5 +1,5 @@ class CTS: - def __init__(self): + def __init__(self) -> None: self.name = "cts_stage" def info(self) -> str: diff --git a/backend/src/openroad_mcp/server/orfs/pipeline/final_report.py b/backend/src/openroad_mcp/server/orfs/pipeline/final_report.py index 79f969c0..fac7ff05 100644 --- a/backend/src/openroad_mcp/server/orfs/pipeline/final_report.py +++ b/backend/src/openroad_mcp/server/orfs/pipeline/final_report.py @@ -1,5 +1,5 @@ class FinalReport: - def __init__(self): + def __init__(self) -> None: self.name = "final_report_stage" def info(self) -> str: diff --git a/backend/src/openroad_mcp/server/orfs/pipeline/floorplan.py b/backend/src/openroad_mcp/server/orfs/pipeline/floorplan.py index b4886934..e8ec7343 100644 --- a/backend/src/openroad_mcp/server/orfs/pipeline/floorplan.py +++ b/backend/src/openroad_mcp/server/orfs/pipeline/floorplan.py @@ -1,5 +1,5 @@ class Floorplan: - def __init__(self): + def __init__(self) -> None: self.name = "floorplan_stage" def info(self) -> str: diff --git a/backend/src/openroad_mcp/server/orfs/pipeline/placement.py b/backend/src/openroad_mcp/server/orfs/pipeline/placement.py index 4bb12394..1b008035 100644 --- a/backend/src/openroad_mcp/server/orfs/pipeline/placement.py +++ b/backend/src/openroad_mcp/server/orfs/pipeline/placement.py @@ -1,5 +1,5 @@ class Placement: - def __init__(self): + def __init__(self) -> None: self.name = "placement_stage" def info(self) -> str: diff --git a/backend/src/openroad_mcp/server/orfs/pipeline/routing.py b/backend/src/openroad_mcp/server/orfs/pipeline/routing.py index 9226b33d..cf721eab 100644 --- a/backend/src/openroad_mcp/server/orfs/pipeline/routing.py +++ b/backend/src/openroad_mcp/server/orfs/pipeline/routing.py @@ -1,5 +1,5 @@ class Routing: - def __init__(self): + def __init__(self) -> None: self.name = "routing_stage" def info(self) -> str: diff --git a/backend/src/openroad_mcp/server/orfs/pipeline/synthesis.py b/backend/src/openroad_mcp/server/orfs/pipeline/synthesis.py index 3f57306c..c2791ee7 100644 --- a/backend/src/openroad_mcp/server/orfs/pipeline/synthesis.py +++ b/backend/src/openroad_mcp/server/orfs/pipeline/synthesis.py @@ -1,5 +1,5 @@ class Synthesis: - def __init__(self): + def __init__(self) -> None: self.name = "synthesis_stage" def info(self) -> str: diff --git a/backend/src/prompts/prompt_templates.py b/backend/src/prompts/prompt_templates.py index a1d67989..6ba4c21b 100644 --- a/backend/src/prompts/prompt_templates.py +++ b/backend/src/prompts/prompt_templates.py @@ -161,3 +161,18 @@ Assistant Answer: {assistant_answer} """ + +env_prompt_template = """ +You are an assistant that always answers in JSON format. +Use the given context and the question to produce the answer. +Each key in the JSON should correspond to a variable name, and each value should be a string that is not a sentence but a parameter, number, or boolean that can be exported in the format: +export = + +Context: +{context} + +Question: +{question} + +Return ONLY valid JSON. No explanations, no text outside of the JSON. +""" diff --git a/backend/src/vectorstores/faiss.py b/backend/src/vectorstores/faiss.py index c531a523..f4308dd5 100644 --- a/backend/src/vectorstores/faiss.py +++ b/backend/src/vectorstores/faiss.py @@ -204,14 +204,14 @@ def get_db_path(self) -> str: path = os.path.abspath(path) # Ensure proper parent directory return path - def save_db(self, name) -> None: + def save_db(self, name: str) -> None: if self._faiss_db is None: raise ValueError("No documents in FAISS database") else: save_path = f"{self.get_db_path()}/{name}" self._faiss_db.save_local(save_path) - def load_db(self, name) -> None: + def load_db(self, name: str) -> None: load_path = f"{self.get_db_path()}/{name}" self._faiss_db = FAISS.load_local( load_path, self.embedding_model, allow_dangerous_deserialization=True diff --git a/backend/tests/test_orfs_base.py b/backend/tests/test_orfs_base.py new file mode 100644 index 00000000..37cc671b --- /dev/null +++ b/backend/tests/test_orfs_base.py @@ -0,0 +1,349 @@ +import pytest +import os +import subprocess +from unittest.mock import Mock, patch +from src.openroad_mcp.server.orfs.orfs_base import ORFSBase, _should_skip_gui +from src.openroad_mcp.server.orfs.orfs_tools import ORFS + + +class TestORFSBase: + """Test suite for ORFSBase class.""" + + @pytest.fixture + def mock_orfs_server(self, tmp_path): + """Create a mock ORFS server instance.""" + mock_server = Mock() + mock_server.platform = None + mock_server.design = None + mock_server.flow_dir = str(tmp_path / "flow") + mock_server.env = os.environ.copy() + mock_server.cur_stage = 0 + + # Create mock stages + mock_server.stages = { + 0: Mock(info=lambda: "synth"), + 1: Mock(info=lambda: "floorplan"), + 2: Mock(info=lambda: "place"), + 3: Mock(info=lambda: "cts"), + 4: Mock(info=lambda: "route"), + 5: Mock(info=lambda: "final"), + } + mock_server.stage_index = { + "synth": 0, + "floorplan": 1, + "place": 2, + "cts": 3, + "route": 4, + "final": 5, + } + + # Create flow directory + flow_dir = tmp_path / "flow" + flow_dir.mkdir(parents=True, exist_ok=True) + + return mock_server + + def test_should_skip_gui_true(self): + """Test _should_skip_gui returns True when DISABLE_GUI is set.""" + with patch.dict(os.environ, {"DISABLE_GUI": "true"}): + assert _should_skip_gui() is True + + with patch.dict(os.environ, {"DISABLE_GUI": "1"}): + assert _should_skip_gui() is True + + with patch.dict(os.environ, {"DISABLE_GUI": "yes"}): + assert _should_skip_gui() is True + + def test_should_skip_gui_false(self): + """Test _should_skip_gui returns False when DISABLE_GUI is not set or false.""" + with patch.dict(os.environ, {"DISABLE_GUI": "false"}): + assert _should_skip_gui() is False + + with patch.dict(os.environ, {}, clear=True): + # Default should be false + assert _should_skip_gui() is False + + def test_get_platforms_impl(self, mock_orfs_server): + """Test _get_platforms_impl sets default platform.""" + ORFS.server = mock_orfs_server + base = ORFSBase() + + result = base._get_platforms_impl() + + assert result == "sky130hd" + assert mock_orfs_server.platform == "sky130hd" + + def test_get_designs_impl(self, mock_orfs_server): + """Test _get_designs_impl sets default design.""" + ORFS.server = mock_orfs_server + base = ORFSBase() + + result = base._get_designs_impl() + + assert result == "riscv32i" + assert mock_orfs_server.design == "riscv32i" + + def test_check_configuration_initializes_platform(self, mock_orfs_server): + """Test _check_configuration initializes platform if not set.""" + ORFS.server = mock_orfs_server + mock_orfs_server.platform = None + mock_orfs_server.design = "test_design" + mock_orfs_server._get_platforms_impl = Mock(return_value="sky130hd") + mock_orfs_server._get_designs_impl = Mock() + + base = ORFSBase() + base._check_configuration() + + mock_orfs_server._get_platforms_impl.assert_called_once() + mock_orfs_server._get_designs_impl.assert_not_called() + + def test_check_configuration_initializes_design(self, mock_orfs_server): + """Test _check_configuration initializes design if not set.""" + ORFS.server = mock_orfs_server + mock_orfs_server.platform = "sky130hd" + mock_orfs_server.design = None + mock_orfs_server._get_platforms_impl = Mock() + mock_orfs_server._get_designs_impl = Mock(return_value="riscv32i") + + base = ORFSBase() + base._check_configuration() + + mock_orfs_server._get_platforms_impl.assert_not_called() + mock_orfs_server._get_designs_impl.assert_called_once() + + def test_check_configuration_already_set(self, mock_orfs_server): + """Test _check_configuration does nothing when already configured.""" + ORFS.server = mock_orfs_server + mock_orfs_server.platform = "sky130hd" + mock_orfs_server.design = "riscv32i" + mock_orfs_server._get_platforms_impl = Mock() + mock_orfs_server._get_designs_impl = Mock() + + base = ORFSBase() + base._check_configuration() + + mock_orfs_server._get_platforms_impl.assert_not_called() + mock_orfs_server._get_designs_impl.assert_not_called() + + @patch("subprocess.Popen") + def test_run_command_success(self, mock_popen, mock_orfs_server): + """Test _run_command executes successfully.""" + ORFS.server = mock_orfs_server + + # Mock subprocess + mock_process = Mock() + mock_process.returncode = 0 + mock_process.stdout = iter(["Line 1\n", "Line 2\n", "Complete\n"]) + mock_process.wait = Mock(return_value=0) + mock_popen.return_value = mock_process + + base = ORFSBase() + base._run_command("make synth") + + # Verify subprocess was called + mock_popen.assert_called_once() + assert "['make', 'synth']" in str(mock_popen.call_args) + + @patch("subprocess.Popen") + def test_run_command_failure(self, mock_popen, mock_orfs_server): + """Test _run_command raises exception on failure.""" + ORFS.server = mock_orfs_server + + # Mock subprocess failure + mock_process = Mock() + mock_process.returncode = 1 + mock_process.stdout = iter(["Error occurred\n"]) + mock_process.wait = Mock(return_value=1) + mock_popen.return_value = mock_process + + base = ORFSBase() + + with pytest.raises(subprocess.CalledProcessError): + base._run_command("make synth") + + @patch("subprocess.Popen") + def test_command_changes_directory(self, mock_popen, mock_orfs_server, tmp_path): + """Test _command changes to flow directory and back.""" + ORFS.server = mock_orfs_server + mock_orfs_server.platform = "sky130hd" + mock_orfs_server.design = "riscv32i" + mock_orfs_server._run_command = Mock() + + original_dir = os.getcwd() + + base = ORFSBase() + base._command("synth") + + # Verify we're back in original directory + assert os.getcwd() == original_dir + + # Verify _run_command was called + mock_orfs_server._run_command.assert_called_once() + + def test_get_platforms_tool(self, mock_orfs_server): + """Test get_platforms MCP tool.""" + ORFS.server = mock_orfs_server + + def set_platform(): + mock_orfs_server.platform = "sky130hd" + return "sky130hd" + + mock_orfs_server._get_platforms_impl = Mock(side_effect=set_platform) + + result = ORFSBase.get_platforms.fn() + + assert result == "sky130hd" + assert mock_orfs_server.platform == "sky130hd" + + def test_get_designs_tool(self, mock_orfs_server): + """Test get_designs MCP tool.""" + ORFS.server = mock_orfs_server + + def set_design(): + mock_orfs_server.design = "riscv32i" + return "riscv32i" + + mock_orfs_server._get_designs_impl = Mock(side_effect=set_design) + + result = ORFSBase.get_designs.fn() + + assert result == "riscv32i" + assert mock_orfs_server.design == "riscv32i" + + def test_make_tool(self, mock_orfs_server): + """Test make MCP tool.""" + ORFS.server = mock_orfs_server + mock_orfs_server._check_configuration = Mock() + mock_orfs_server._command = Mock() + + result = ORFSBase.make.fn("clean") + + mock_orfs_server._check_configuration.assert_called_once() + mock_orfs_server._command.assert_called_once_with("clean") + assert result == "finished clean" + + def test_get_stage_names(self, mock_orfs_server): + """Test get_stage_names MCP tool.""" + ORFS.server = mock_orfs_server + + result = ORFSBase.get_stage_names.fn() + + assert "synth" in result + assert "floorplan" in result + assert "place" in result + assert "cts" in result + assert "route" in result + assert "final" in result + + def test_jump_valid_stage(self, mock_orfs_server): + """Test jump to valid stage.""" + ORFS.server = mock_orfs_server + mock_orfs_server._check_configuration = Mock() + mock_orfs_server._command = Mock() + + with patch.dict(os.environ, {"DISABLE_GUI": "true"}): + result = ORFSBase.jump.fn("floorplan") + + assert result == "finished floorplan" + assert mock_orfs_server.cur_stage == 1 + mock_orfs_server._command.assert_called_once_with("floorplan") + + def test_jump_invalid_stage(self, mock_orfs_server): + """Test jump to invalid stage.""" + ORFS.server = mock_orfs_server + mock_orfs_server._check_configuration = Mock() + mock_orfs_server._command = Mock() + + result = ORFSBase.jump.fn("invalid_stage") + + assert result == "aborted invalid_stage" + mock_orfs_server._command.assert_not_called() + + def test_jump_with_gui_enabled(self, mock_orfs_server): + """Test jump attempts to open GUI when not disabled.""" + ORFS.server = mock_orfs_server + mock_orfs_server._check_configuration = Mock() + mock_orfs_server._command = Mock() + + with patch.dict(os.environ, {"DISABLE_GUI": "false"}): + _result = ORFSBase.jump.fn("synth") + + # Should call both synth and gui_synth + assert mock_orfs_server._command.call_count == 2 + mock_orfs_server._command.assert_any_call("synth") + mock_orfs_server._command.assert_any_call("gui_synth") + + def test_jump_gui_failure_doesnt_crash(self, mock_orfs_server): + """Test jump handles GUI command failure gracefully.""" + ORFS.server = mock_orfs_server + mock_orfs_server._check_configuration = Mock() + + # Make GUI command fail + def command_side_effect(cmd): + if cmd.startswith("gui_"): + raise subprocess.CalledProcessError(1, cmd) + + mock_orfs_server._command = Mock(side_effect=command_side_effect) + + with patch.dict(os.environ, {"DISABLE_GUI": "false"}): + result = ORFSBase.jump.fn("synth") + + # Should still return success even though GUI failed + assert result == "finished synth" + + def test_step_advances_stage(self, mock_orfs_server): + """Test step advances to next stage.""" + ORFS.server = mock_orfs_server + mock_orfs_server.cur_stage = 0 + mock_orfs_server._check_configuration = Mock() + mock_orfs_server._command = Mock() + + with patch.dict(os.environ, {"DISABLE_GUI": "true"}): + result = ORFSBase.step.fn() + + assert mock_orfs_server.cur_stage == 1 + assert result == "finished floorplan" + mock_orfs_server._command.assert_called_once_with("floorplan") + + def test_step_at_end_of_pipeline(self, mock_orfs_server): + """Test step at end of pipeline doesn't advance beyond final stage.""" + ORFS.server = mock_orfs_server + mock_orfs_server.cur_stage = 5 # Final stage + mock_orfs_server._check_configuration = Mock() + mock_orfs_server._command = Mock() + + with patch.dict(os.environ, {"DISABLE_GUI": "true"}): + result = ORFSBase.step.fn() + + # Should stay at stage 5 + assert mock_orfs_server.cur_stage == 5 + assert result == "finished final" + + def test_step_with_gui_enabled(self, mock_orfs_server): + """Test step attempts to open GUI when not disabled.""" + ORFS.server = mock_orfs_server + mock_orfs_server.cur_stage = 0 + mock_orfs_server._check_configuration = Mock() + mock_orfs_server._command = Mock() + + with patch.dict(os.environ, {"DISABLE_GUI": "false"}): + _result = ORFSBase.step.fn() + + # Should call both floorplan and gui_floorplan + assert mock_orfs_server._command.call_count == 2 + mock_orfs_server._command.assert_any_call("floorplan") + mock_orfs_server._command.assert_any_call("gui_floorplan") + + def test_get_platforms_requires_server(self): + """Test get_platforms fails without server initialization.""" + ORFS.server = None + + with pytest.raises(AssertionError): + ORFSBase.get_platforms.fn() + + def test_make_requires_server(self): + """Test make fails without server initialization.""" + ORFS.server = None + + with pytest.raises(AssertionError): + ORFSBase.make.fn("synth") diff --git a/backend/tests/test_orfs_make.py b/backend/tests/test_orfs_make.py new file mode 100644 index 00000000..d4c3b256 --- /dev/null +++ b/backend/tests/test_orfs_make.py @@ -0,0 +1,219 @@ +import pytest +import os +from unittest.mock import Mock, patch +from src.openroad_mcp.server.orfs.orfs_make import ORFSMake +from src.openroad_mcp.server.orfs.orfs_tools import ORFS + + +class TestORFSMake: + """Test suite for ORFSMake class.""" + + @pytest.fixture + def mock_orfs_server(self, tmp_path): + """Create a mock ORFS server instance.""" + mock_server = Mock() + mock_server.platform = "sky130hd" + mock_server.design = "riscv32i" + mock_server.flow_dir = str(tmp_path / "flow") + mock_server.orfs_env = {} + mock_server.makefile_pointer = None + mock_server.dynamic_makefile = False + + # Create necessary directory structure + design_dir = tmp_path / "flow" / "designs" / "sky130hd" / "riscv32i" + design_dir.mkdir(parents=True, exist_ok=True) + + return mock_server + + def test_get_default_makefile(self, mock_orfs_server): + """Test _get_default_makefile sets correct path.""" + ORFS.server = mock_orfs_server + make = ORFSMake() + + make._get_default_makefile() + + expected_path = f"{mock_orfs_server.flow_dir}/designs/{mock_orfs_server.platform}/{mock_orfs_server.design}/config.mk" + assert mock_orfs_server.makefile_pointer == expected_path + + def test_get_makefile(self, mock_orfs_server): + """Test _get_makefile returns current pointer.""" + ORFS.server = mock_orfs_server + mock_orfs_server.makefile_pointer = "/path/to/config.mk" + + make = ORFSMake() + result = make._get_makefile() + + assert result == "/path/to/config.mk" + + def test_get_default_env(self, mock_orfs_server): + """Test _get_default_env initializes standard variables.""" + ORFS.server = mock_orfs_server + make = ORFSMake() + + make._get_default_env() + + # Check required variables are set + assert "PLATFORM" in mock_orfs_server.orfs_env + assert "DESIGN_NAME" in mock_orfs_server.orfs_env + assert "DESIGN_NICKNAME" in mock_orfs_server.orfs_env + assert "VERILOG_FILES" in mock_orfs_server.orfs_env + assert "SDC_FILE" in mock_orfs_server.orfs_env + assert "CORE_UTILIZATION" in mock_orfs_server.orfs_env + assert "PLACE_DENSITY" in mock_orfs_server.orfs_env + + # Check values match server config + assert mock_orfs_server.orfs_env["PLATFORM"] == "sky130hd" + assert mock_orfs_server.orfs_env["DESIGN_NAME"] == "riscv32i" + assert mock_orfs_server.orfs_env["CORE_UTILIZATION"] == "50" + + def test_create_dynamic_makefile_no_env_vars(self, mock_orfs_server): + """Test create_dynamic_makefile returns 'no env vars' when environment is empty.""" + ORFS.server = mock_orfs_server + mock_orfs_server.orfs_env = {} + + # Access the underlying function from the FunctionTool + result = ORFSMake.create_dynamic_makefile.fn("test") + + assert result == "no env vars" + + def test_create_dynamic_makefile_with_env_vars(self, mock_orfs_server, tmp_path): + """Test create_dynamic_makefile writes file with environment variables.""" + ORFS.server = mock_orfs_server + mock_orfs_server.orfs_env = { + "PLATFORM": "sky130hd", + "DESIGN_NAME": "riscv32i", + "CORE_UTILIZATION": "50", + } + + result = ORFSMake.create_dynamic_makefile.fn("test") + + # Check dynamic_makefile flag is set + assert mock_orfs_server.dynamic_makefile is True + + # Check makefile_pointer is set + expected_path = f"{mock_orfs_server.flow_dir}/designs/{mock_orfs_server.platform}/{mock_orfs_server.design}/dynamic_config.mk" + assert mock_orfs_server.makefile_pointer == expected_path + + # Check file was created + assert os.path.exists(expected_path) + + # Check file contents + with open(expected_path, "r") as f: + content = f.read() + assert "export PLATFORM = sky130hd" in content + assert "export DESIGN_NAME = riscv32i" in content + assert "export CORE_UTILIZATION = 50" in content + + # Check result format + assert "PLATFORM: sky130hd" in result + assert "DESIGN_NAME: riscv32i" in result + assert "CORE_UTILIZATION: 50" in result + + def test_create_dynamic_makefile_initializes_design_platform( + self, mock_orfs_server + ): + """Test create_dynamic_makefile initializes design/platform if not set.""" + ORFS.server = mock_orfs_server + mock_orfs_server.platform = None + mock_orfs_server.design = None + mock_orfs_server.orfs_env = {"TEST": "value"} + + # Mock the initialization methods + mock_orfs_server._get_designs_impl = Mock() + mock_orfs_server._get_platforms_impl = Mock() + mock_orfs_server._get_default_env = Mock() + + # Set values after init calls + def set_platform(): + mock_orfs_server.platform = "sky130hd" + + def set_design(): + mock_orfs_server.design = "riscv32i" + + mock_orfs_server._get_platforms_impl.side_effect = set_platform + mock_orfs_server._get_designs_impl.side_effect = set_design + + _result = ORFSMake.create_dynamic_makefile.fn("test") + + # Verify initialization was called + mock_orfs_server._get_designs_impl.assert_called_once() + mock_orfs_server._get_platforms_impl.assert_called_once() + mock_orfs_server._get_default_env.assert_called_once() + + @patch("src.openroad_mcp.server.orfs.orfs_make.JsonOutputParser") + @patch("src.openroad_mcp.server.orfs.orfs_make.ChatPromptTemplate") + def test_get_env_vars_success( + self, mock_prompt_template, mock_json_parser, mock_orfs_server + ): + """Test get_env_vars extracts variables from LLM.""" + ORFS.server = mock_orfs_server + ORFS.llm = Mock() + + # Mock initialization methods + mock_orfs_server._get_designs_impl = Mock() + mock_orfs_server._get_platforms_impl = Mock() + mock_orfs_server._get_default_env = Mock() + mock_orfs_server.retrieve_general = Mock( + return_value=["documentation context", []] + ) + + # Mock the LLM chain + mock_chain = Mock() + mock_chain.invoke.return_value = { + "CLOCK_PERIOD": "10.0", + "CORE_UTILIZATION": "60", + } + + # Setup chain building + mock_prompt_template.from_template.return_value.__or__ = Mock( + return_value=Mock(__or__=Mock(return_value=mock_chain)) + ) + + result = ORFSMake.get_env_vars.fn("get clock and utilization variables") + + # Check return value + assert result == "done env" + + # Verify environment was updated + assert "CLOCK_PERIOD" in mock_orfs_server.orfs_env + assert "CORE_UTILIZATION" in mock_orfs_server.orfs_env + assert mock_orfs_server.orfs_env["CLOCK_PERIOD"] == "10.0" + assert mock_orfs_server.orfs_env["CORE_UTILIZATION"] == "60" + + def test_get_env_vars_requires_server(self): + """Test get_env_vars fails without server initialization.""" + ORFS.server = None + + with pytest.raises(AssertionError): + ORFSMake.get_env_vars.fn("test") + + def test_get_env_vars_requires_llm(self, mock_orfs_server): + """Test get_env_vars fails without LLM initialization.""" + ORFS.server = mock_orfs_server + ORFS.llm = None + + with pytest.raises(AssertionError): + ORFSMake.get_env_vars.fn("test") + + def test_create_dynamic_makefile_with_makefile_syntax( + self, mock_orfs_server, tmp_path + ): + """Test that create_dynamic_makefile preserves Makefile syntax like $(shell ...).""" + ORFS.server = mock_orfs_server + mock_orfs_server.orfs_env = { + "PLATFORM": "asap7", + "CORE_AREA": "$(shell export VAR=1 && echo 100)", # Legitimate Makefile syntax + "DIE_AREA": "$(PLATFORM_DIR)/$(DESIGN_NAME)", # Variable reference + } + + _result = ORFSMake.create_dynamic_makefile.fn("test") + + # Check file was created + expected_path = f"{mock_orfs_server.flow_dir}/designs/{mock_orfs_server.platform}/{mock_orfs_server.design}/dynamic_config.mk" + assert os.path.exists(expected_path) + + # Verify Makefile syntax is preserved + with open(expected_path, "r") as f: + content = f.read() + assert "export CORE_AREA = $(shell export VAR=1 && echo 100)" in content + assert "export DIE_AREA = $(PLATFORM_DIR)/$(DESIGN_NAME)" in content diff --git a/backend/tests/test_orfs_rag.py b/backend/tests/test_orfs_rag.py new file mode 100644 index 00000000..49cdd77b --- /dev/null +++ b/backend/tests/test_orfs_rag.py @@ -0,0 +1,237 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +import sys + +# Mock HybridRetrieverChain before importing orfs_rag to prevent model downloads during import +with patch("src.chains.hybrid_retriever_chain.HybridRetrieverChain") as mock_chain_class: + mock_instance = MagicMock() + mock_instance.create_hybrid_retriever = MagicMock() + mock_instance.retriever = MagicMock() + mock_chain_class.return_value = mock_instance + + # Now it's safe to import these modules + from src.openroad_mcp.server.orfs.orfs_tools import ORFS + # Import orfs_rag with mocked HybridRetrieverChain + if "src.openroad_mcp.server.orfs.orfs_rag" in sys.modules: + del sys.modules["src.openroad_mcp.server.orfs.orfs_rag"] + from src.openroad_mcp.server.orfs.orfs_rag import ORFSRag + + +class TestORFSRag: + """Test suite for ORFSRag retrieval methods.""" + + @pytest.fixture(autouse=True) + def setup_orfs(self): + """Setup ORFS class attributes before each test.""" + # Save original values + self.original_general = ORFS.general_retriever + self.original_commands = ORFS.commands_retriever + self.original_install = ORFS.install_retriever + self.original_errinfo = ORFS.errinfo_retriever + self.original_yosys = ORFS.yosys_rtdocs_retriever + self.original_klayout = ORFS.klayout_retriever + + yield + + # Restore original values + ORFS.general_retriever = self.original_general + ORFS.commands_retriever = self.original_commands + ORFS.install_retriever = self.original_install + ORFS.errinfo_retriever = self.original_errinfo + ORFS.yosys_rtdocs_retriever = self.original_yosys + ORFS.klayout_retriever = self.original_klayout + + @patch("src.openroad_mcp.server.orfs.orfs_rag.format_docs") + def test_retrieve_general_success(self, mock_format_docs): + """Test retrieve_general returns formatted docs.""" + # Mock retriever + mock_retriever = Mock() + mock_docs = [Mock(), Mock()] + mock_retriever.invoke.return_value = mock_docs + ORFS.general_retriever = mock_retriever + + # Mock format_docs + mock_format_docs.return_value = ( + "formatted text", + ["source1"], + ["title1"], + ["link1"], + ) + + result = ORFSRag.retrieve_general.fn("test query") + + # Verify retriever was called + mock_retriever.invoke.assert_called_once_with(input="test query") + + # Verify format_docs was called with docs + mock_format_docs.assert_called_once_with(mock_docs) + + # Verify result + assert result == ("formatted text", ["source1"], ["title1"], ["link1"]) + + def test_retrieve_general_not_initialized(self): + """Test retrieve_general raises error when retriever not initialized.""" + ORFS.general_retriever = None + + with pytest.raises(ValueError, match="General Retriever not initialized"): + ORFSRag.retrieve_general.fn("test query") + + @patch("src.openroad_mcp.server.orfs.orfs_rag.format_docs") + def test_retrieve_cmds_success(self, mock_format_docs): + """Test retrieve_cmds returns formatted docs.""" + # Mock retriever + mock_retriever = Mock() + mock_docs = [Mock(), Mock()] + mock_retriever.invoke.return_value = mock_docs + ORFS.commands_retriever = mock_retriever + + # Mock format_docs + mock_format_docs.return_value = ( + "formatted text", + ["source1"], + ["title1"], + ["link1"], + ) + + result = ORFSRag.retrieve_cmds.fn("make command") + + # Verify retriever was called + mock_retriever.invoke.assert_called_once_with(input="make command") + + # Verify result + assert result == ("formatted text", ["source1"], ["title1"], ["link1"]) + + def test_retrieve_cmds_not_initialized(self): + """Test retrieve_cmds raises error when retriever not initialized.""" + ORFS.commands_retriever = None + + with pytest.raises(ValueError, match="Commands Retriever not initialized"): + ORFSRag.retrieve_cmds.fn("test query") + + @patch("src.openroad_mcp.server.orfs.orfs_rag.format_docs") + def test_retrieve_install_success(self, mock_format_docs): + """Test retrieve_install returns formatted docs.""" + # Mock retriever + mock_retriever = Mock() + mock_docs = [Mock()] + mock_retriever.invoke.return_value = mock_docs + ORFS.install_retriever = mock_retriever + + # Mock format_docs + mock_format_docs.return_value = ( + "formatted text", + ["source1"], + ["title1"], + ["link1"], + ) + + result = ORFSRag.retrieve_install.fn("installation steps") + + # Verify retriever was called + mock_retriever.invoke.assert_called_once_with(input="installation steps") + + # Verify result + assert result == ("formatted text", ["source1"], ["title1"], ["link1"]) + + def test_retrieve_install_not_initialized(self): + """Test retrieve_install raises error when retriever not initialized.""" + ORFS.install_retriever = None + + with pytest.raises(ValueError, match="Install Retriever not initialized"): + ORFSRag.retrieve_install.fn("test query") + + @patch("src.openroad_mcp.server.orfs.orfs_rag.format_docs") + def test_retrieve_errinfo_success(self, mock_format_docs): + """Test retrieve_errinfo returns formatted docs.""" + # Mock retriever + mock_retriever = Mock() + mock_docs = [Mock()] + mock_retriever.invoke.return_value = mock_docs + ORFS.errinfo_retriever = mock_retriever + + # Mock format_docs + mock_format_docs.return_value = ( + "formatted text", + ["source1"], + ["title1"], + ["link1"], + ) + + result = ORFSRag.retrieve_errinfo.fn("ANT-0001") + + # Verify retriever was called + mock_retriever.invoke.assert_called_once_with(input="ANT-0001") + + # Verify result + assert result == ("formatted text", ["source1"], ["title1"], ["link1"]) + + def test_retrieve_errinfo_not_initialized(self): + """Test retrieve_errinfo raises error when retriever not initialized.""" + ORFS.errinfo_retriever = None + + with pytest.raises(ValueError, match="Error Info Retriever not initialized"): + ORFSRag.retrieve_errinfo.fn("test query") + + @patch("src.openroad_mcp.server.orfs.orfs_rag.format_docs") + def test_retrieve_yosys_rtdocs_success(self, mock_format_docs): + """Test retrieve_yosys_rtdocs returns formatted docs.""" + # Mock retriever + mock_retriever = Mock() + mock_docs = [Mock()] + mock_retriever.invoke.return_value = mock_docs + ORFS.yosys_rtdocs_retriever = mock_retriever + + # Mock format_docs + mock_format_docs.return_value = ( + "formatted text", + ["source1"], + ["title1"], + ["link1"], + ) + + result = ORFSRag.retrieve_yosys_rtdocs.fn("yosys synthesis") + + # Verify retriever was called + mock_retriever.invoke.assert_called_once_with(input="yosys synthesis") + + # Verify result + assert result == ("formatted text", ["source1"], ["title1"], ["link1"]) + + def test_retrieve_yosys_rtdocs_not_initialized(self): + """Test retrieve_yosys_rtdocs raises error when retriever not initialized.""" + ORFS.yosys_rtdocs_retriever = None + + with pytest.raises(ValueError, match="Yosys RTDocs Retriever not initialized"): + ORFSRag.retrieve_yosys_rtdocs.fn("test query") + + @patch("src.openroad_mcp.server.orfs.orfs_rag.format_docs") + def test_retrieve_klayout_docs_success(self, mock_format_docs): + """Test retrieve_klayout_docs returns formatted docs.""" + # Mock retriever + mock_retriever = Mock() + mock_docs = [Mock()] + mock_retriever.invoke.return_value = mock_docs + ORFS.klayout_retriever = mock_retriever + + # Mock format_docs + mock_format_docs.return_value = ( + "formatted text", + ["source1"], + ["title1"], + ["link1"], + ) + + result = ORFSRag.retrieve_klayout_docs.fn("klayout usage") + + # Verify retriever was called + mock_retriever.invoke.assert_called_once_with(input="klayout usage") + + # Verify result + assert result == ("formatted text", ["source1"], ["title1"], ["link1"]) + + def test_retrieve_klayout_docs_not_initialized(self): + """Test retrieve_klayout_docs raises error when retriever not initialized.""" + ORFS.klayout_retriever = None + + with pytest.raises(ValueError, match="KLayout Retriever not initialized"): + ORFSRag.retrieve_klayout_docs.fn("test query") diff --git a/backend/tests/test_retriever_tools.py b/backend/tests/test_retriever_tools.py index 584ce378..f6bd25d0 100644 --- a/backend/tests/test_retriever_tools.py +++ b/backend/tests/test_retriever_tools.py @@ -107,7 +107,7 @@ def test_retrieve_general_success(self, mock_format_docs): formatted_result = ("formatted_text", ["source1"], ["url1"], ["context1"]) mock_format_docs.return_value = formatted_result - result = RetrieverTools.retrieve_general("test query") + result = RetrieverTools.retrieve_general.invoke(input="test query") assert result == formatted_result mock_retriever.invoke.assert_called_once_with(input="test query") @@ -118,7 +118,7 @@ def test_retrieve_general_not_initialized(self): RetrieverTools.general_retriever = None with pytest.raises(ValueError, match="General Retriever not initialized"): - RetrieverTools.retrieve_general("test query") + RetrieverTools.retrieve_general.invoke(input="test query") @patch("src.agents.retriever_tools.format_docs") def test_retrieve_cmds_success(self, mock_format_docs): @@ -133,7 +133,7 @@ def test_retrieve_cmds_success(self, mock_format_docs): formatted_result = ("formatted_text", ["source1"], ["url1"], ["context1"]) mock_format_docs.return_value = formatted_result - result = RetrieverTools.retrieve_cmds("test query") + result = RetrieverTools.retrieve_cmds.invoke(input="test query") assert result == formatted_result mock_retriever.invoke.assert_called_once_with(input="test query") @@ -144,7 +144,7 @@ def test_retrieve_cmds_not_initialized(self): RetrieverTools.commands_retriever = None with pytest.raises(ValueError, match="Commands Retriever not initialized"): - RetrieverTools.retrieve_cmds("test query") + RetrieverTools.retrieve_cmds.invoke(input="test query") @patch("src.agents.retriever_tools.format_docs") def test_retrieve_install_success(self, mock_format_docs): @@ -159,7 +159,7 @@ def test_retrieve_install_success(self, mock_format_docs): formatted_result = ("formatted_text", ["source1"], ["url1"], ["context1"]) mock_format_docs.return_value = formatted_result - result = RetrieverTools.retrieve_install("test query") + result = RetrieverTools.retrieve_install.invoke(input="test query") assert result == formatted_result mock_retriever.invoke.assert_called_once_with(input="test query") @@ -170,7 +170,7 @@ def test_retrieve_install_not_initialized(self): RetrieverTools.install_retriever = None with pytest.raises(ValueError, match="Install Retriever not initialized"): - RetrieverTools.retrieve_install("test query") + RetrieverTools.retrieve_install.invoke(input="test query") @patch("src.agents.retriever_tools.format_docs") def test_retrieve_errinfo_success(self, mock_format_docs): @@ -185,7 +185,7 @@ def test_retrieve_errinfo_success(self, mock_format_docs): formatted_result = ("formatted_text", ["source1"], ["url1"], ["context1"]) mock_format_docs.return_value = formatted_result - result = RetrieverTools.retrieve_errinfo("test query") + result = RetrieverTools.retrieve_errinfo.invoke(input="test query") assert result == formatted_result mock_retriever.invoke.assert_called_once_with(input="test query") @@ -196,7 +196,7 @@ def test_retrieve_errinfo_not_initialized(self): RetrieverTools.errinfo_retriever = None with pytest.raises(ValueError, match="Error Info Retriever not initialized"): - RetrieverTools.retrieve_errinfo("test query") + RetrieverTools.retrieve_errinfo.invoke(input="test query") @patch("src.agents.retriever_tools.format_docs") def test_retrieve_yosys_rtdocs_success(self, mock_format_docs): @@ -211,7 +211,7 @@ def test_retrieve_yosys_rtdocs_success(self, mock_format_docs): formatted_result = ("formatted_text", ["source1"], ["url1"], ["context1"]) mock_format_docs.return_value = formatted_result - result = RetrieverTools.retrieve_yosys_rtdocs("test query") + result = RetrieverTools.retrieve_yosys_rtdocs.invoke(input="test query") assert result == formatted_result mock_retriever.invoke.assert_called_once_with(input="test query") @@ -222,7 +222,7 @@ def test_retrieve_yosys_rtdocs_not_initialized(self): RetrieverTools.yosys_rtdocs_retriever = None with pytest.raises(ValueError, match="Yosys RTDocs Retriever not initialized"): - RetrieverTools.retrieve_yosys_rtdocs("test query") + RetrieverTools.retrieve_yosys_rtdocs.invoke(input="test query") @patch("src.agents.retriever_tools.format_docs") def test_retrieve_klayout_docs_success(self, mock_format_docs): @@ -237,7 +237,7 @@ def test_retrieve_klayout_docs_success(self, mock_format_docs): formatted_result = ("formatted_text", ["source1"], ["url1"], ["context1"]) mock_format_docs.return_value = formatted_result - result = RetrieverTools.retrieve_klayout_docs("test query") + result = RetrieverTools.retrieve_klayout_docs.invoke(input="test query") assert result == formatted_result mock_retriever.invoke.assert_called_once_with(input="test query") @@ -248,7 +248,7 @@ def test_retrieve_klayout_docs_not_initialized(self): RetrieverTools.klayout_retriever = None with pytest.raises(ValueError, match="KLayout Retriever not initialized"): - RetrieverTools.retrieve_klayout_docs("test query") + RetrieverTools.retrieve_klayout_docs.invoke(input="test query") @patch("src.agents.retriever_tools.HybridRetrieverChain") def test_initialize_verifies_configuration_parameters(self, mock_hybrid_chain): @@ -423,7 +423,7 @@ def test_staticmethod_decorators(self): mock_format.return_value = ("", [], [], []) # Should be able to call without creating instance - result = RetrieverTools.retrieve_general("test") + result = RetrieverTools.retrieve_general.invoke(input="test") assert result == ("", [], [], []) @patch("src.agents.retriever_tools.HybridRetrieverChain")