diff --git a/scripts/doc-loader.py b/scripts/doc_loader.py similarity index 79% rename from scripts/doc-loader.py rename to scripts/doc_loader.py index ff7a2db..25318a4 100644 --- a/scripts/doc-loader.py +++ b/scripts/doc_loader.py @@ -6,6 +6,7 @@ from libs.dataloader.document import DocumentLoader from core.rag.embedder import TextEmbedding3Small from core.rag.dbhandler.memgraph import MemGraphClient +from loader_config import DOC_LOADER_CONFIGS import asyncio import os @@ -38,12 +39,14 @@ def store(source, doc, chunks, vectors): db.create_vector(vector) async def main(): - loadar_paths = ["/Users/nullchimp/Projects/customer-security-trust/FAQ"] - for path in loadar_paths: - loader = DocumentLoader(path, ['.md']) + for config in DOC_LOADER_CONFIGS: + loader = DocumentLoader(config.path, config.file_extensions or ['.md']) for source, doc, chunks in loader.load_data(): vectors = [] await embedder.process_chunks(chunks, callback=lambda v: vectors.append(v)) + if config.uri_replacement: + old_pattern, new_pattern = config.uri_replacement + source.uri = f"{source.uri.replace(old_pattern, new_pattern)}" store(source, doc, chunks, vectors) asyncio.run(main()) \ No newline at end of file diff --git a/scripts/loader_config.py b/scripts/loader_config.py new file mode 100644 index 0000000..aabb278 --- /dev/null +++ b/scripts/loader_config.py @@ -0,0 +1,40 @@ +from typing import List, Tuple, Optional +from dataclasses import dataclass + +@dataclass +class LoaderConfig: + path: str + file_extensions: Optional[List[str]] = None + uri_replacement: Optional[Tuple[str, str]] = None + + +@dataclass +class WebLoaderConfig: + url: str + uri_replacement: Optional[Tuple[str, str]] = None + + +DOC_LOADER_CONFIGS = [ + LoaderConfig( + path="/Users/nullchimp/Projects/customer-security-trust/FAQ", + file_extensions=['.md'], + uri_replacement=( + "/Users/nullchimp/Projects/customer-security-trust/FAQ", + "https://github.com/github/customer-security-trust/blob/main/FAQ" + ) + ), + LoaderConfig( + path="/Users/nullchimp/Projects/github-docs/content-copilot", + file_extensions=['.md'] + ) +] + +WEB_LOADER_CONFIGS = [ + WebLoaderConfig( + url="http://localhost:4000/en/enterprise-cloud@latest", + uri_replacement=( + "http://localhost:4000", + "https://docs.github.com" + ) + ) +] diff --git a/docker/memgraph.sh b/scripts/memgraph.sh similarity index 99% rename from docker/memgraph.sh rename to scripts/memgraph.sh index e8492e8..558fb85 100755 --- a/docker/memgraph.sh +++ b/scripts/memgraph.sh @@ -7,7 +7,7 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" -cd "$SCRIPT_DIR" +cd "$PROJECT_ROOT/docker" # Function to print colored output print_message() { diff --git a/scripts/url-loader.py b/scripts/url_loader.py similarity index 71% rename from scripts/url-loader.py rename to scripts/url_loader.py index 10ce959..ee56ccf 100644 --- a/scripts/url-loader.py +++ b/scripts/url_loader.py @@ -4,8 +4,8 @@ from core.rag.embedder import TextEmbedding3Small from core.rag.dbhandler.memgraph import MemGraphClient - from libs.dataloader.web import WebLoader +from loader_config import WEB_LOADER_CONFIGS import asyncio import os @@ -19,7 +19,6 @@ print("Connected to Memgraph", db.host, db.port) -loader = WebLoader("http://localhost:4000/en/enterprise-cloud@latest/copilot") embedder = TextEmbedding3Small() vector_store = db.create_vector_store( @@ -45,10 +44,15 @@ def store(source, doc, chunks, vectors): print("### Data stored successfully") async def main(): - for source, doc, chunks in loader.load_data(): - vectors = [] - await embedder.process_chunks(chunks, callback=lambda v: vectors.append(v)) - store(source, doc, chunks, vectors) + for config in WEB_LOADER_CONFIGS: + loader = WebLoader(config.url) + for source, doc, chunks in loader.load_data(): + vectors = [] + await embedder.process_chunks(chunks, callback=lambda v: vectors.append(v)) + if config.uri_replacement: + old_pattern, new_pattern = config.uri_replacement + source.uri = f"{source.uri.replace(old_pattern, new_pattern)}" + store(source, doc, chunks, vectors) db.close() diff --git a/src/agent.py b/src/agent.py index a11fceb..42eaee1 100644 --- a/src/agent.py +++ b/src/agent.py @@ -26,28 +26,41 @@ def __init__(self): # Define enhanced system role with instructions on using all available tools self.system_role = f""" -You are a helpful assistant. -Your Name is Agent Smith. + You are a helpful assistant. + Your Name is Agent Smith. -Whenever you are not sure about something, have a look at the tools available to you. -On GitHub related questions: -- Use the GitHub Knowledgebase tool, which is the only reliable source. -- Only if you cannot find the answer there, use the Google Search tool, which is less reliable. + On GitHub related questions: + - You MUST always use the GitHub Knowledgebase tool, which is the only reliable source. + - Never make up answers, ALWAYS back them up with facts from the GitHub Knowledgebase. -MCP Servers may provide additional tools, which you can use to execute tasks. + On general questions or when the GitHub Knowledgebase does not have the answer: + - You can use the Google Search tool to find information. + - You can also use the Read File tool to read files, Write File tool to write files, and List Files tool to list files. + - If you need to use a tool, you MUST call it explicitly. -You MUST provide the most up-to-date and most accurate information. -You MUST synthesize and cite your sources correctly, but keep responses concise. + On any task that requires external information: + - You MUST use the tools provided to you by MCP Servers. + - You MUST NOT make up answers or provide information without using the tools. + - If you do not know the answer, you MUST say "I don't know" instead of making up an answer. -Today is {date.today().strftime("%d %B %Y")}. -""" + You MUST provide the most up-to-date and most accurate information. + You MUST synthesize and cite your sources correctly, but keep responses concise. + + Today is {date.today().strftime("%d %B %Y")}. + """ + + self.history = [ + {"role": "system", "content": self.system_role} + ] def add_tool(self, tool: Tool) -> None: self.chat.add_tool(tool) async def process_query(self, user_prompt: str) -> str: - messages = [{"role": "system", "content": self.system_role}] - messages.append({"role": "user", "content": user_prompt}) + user_role = {"role": "user", "content": user_prompt} + + messages = list(self.history) + messages.append(user_role) response = await self.chat.send_messages(messages) choices = response.get("choices", []) @@ -67,6 +80,11 @@ async def process_query(self, user_prompt: str) -> str: messages.append(assistant_message) result = assistant_message.get("content", "") + if result: + self.history.append(user_role) + self.history.append(assistant_message) + + pretty_print("History", self.history) return result diff --git a/src/core/rag/embedder/__init__.py b/src/core/rag/embedder/__init__.py index ccedece..e739f6e 100644 --- a/src/core/rag/embedder/__init__.py +++ b/src/core/rag/embedder/__init__.py @@ -22,6 +22,7 @@ async def process_chunk( callback: callable = None ) -> None: """Process a single chunk: generate embedding and store in vector DB""" + try: embedding = await self._make_embedding_request(chunk.content) except Exception as e: @@ -69,11 +70,13 @@ async def _make_embedding_request(self, text: str, retry = 3) -> List[float]: raise ValueError("Failed to get embedding from Azure OpenAI") except Exception as e: - if "429" in str(e) and retry > 1: - await asyncio.sleep(60) # Wait for 1 minute before retrying - return await self._make_embedding_request(text, retry=retry-1) + if retry <= 0: + raise ValueError(f"Failed to get embedding after retries: {str(e)}") + + await asyncio.sleep(5) # Wait for 5 seconds before retrying + if "429" in str(e): + await asyncio.sleep(55) # Wait for a total of 1 minute before retrying - # Re-raise the exception if it's not a 429 error or if retries are exhausted - raise + return await self._make_embedding_request(text, retry=retry-1) from core.rag.embedder.text_embedding_3_small import TextEmbedding3Small \ No newline at end of file diff --git a/src/core/rag/embedder/text_embedding_3_small.py b/src/core/rag/embedder/text_embedding_3_small.py index 15ccca9..6d60f38 100644 --- a/src/core/rag/embedder/text_embedding_3_small.py +++ b/src/core/rag/embedder/text_embedding_3_small.py @@ -1,4 +1,3 @@ -import uuid from core.rag.schema import DocumentChunk from . import EmbeddingService diff --git a/tests/test_core_rag_embedder.py b/tests/test_core_rag_embedder.py index a834c72..b6540af 100644 --- a/tests/test_core_rag_embedder.py +++ b/tests/test_core_rag_embedder.py @@ -142,7 +142,9 @@ async def test_make_embedding_request_retry_429(self, service): result = await service._make_embedding_request("test text", retry=2) assert result == [0.1, 0.2, 0.3] - mock_sleep.assert_called_once_with(60) + assert mock_sleep.call_count == 2 + mock_sleep.assert_any_call(5) + mock_sleep.assert_any_call(55) assert service._client.make_request.call_count == 2 @pytest.mark.asyncio diff --git a/tests/test_scripts_loader_config.py b/tests/test_scripts_loader_config.py new file mode 100644 index 0000000..2d1ecbe --- /dev/null +++ b/tests/test_scripts_loader_config.py @@ -0,0 +1,114 @@ +import pytest +from unittest.mock import Mock, patch +from scripts.loader_config import DOC_LOADER_CONFIGS, WEB_LOADER_CONFIGS, LoaderConfig, WebLoaderConfig + + +def test_loader_config_structure(): + assert len(DOC_LOADER_CONFIGS) >= 1 + assert all(isinstance(config, LoaderConfig) for config in DOC_LOADER_CONFIGS) + + first_config = DOC_LOADER_CONFIGS[0] + assert first_config.path + assert first_config.file_extensions + assert first_config.uri_replacement + + +def test_web_loader_config_structure(): + assert len(WEB_LOADER_CONFIGS) >= 1 + assert all(isinstance(config, WebLoaderConfig) for config in WEB_LOADER_CONFIGS) + + first_config = WEB_LOADER_CONFIGS[0] + assert first_config.url + assert first_config.uri_replacement + + +def test_loader_config_uri_replacement(): + config = LoaderConfig( + path="/test/path", + file_extensions=['.md'], + uri_replacement=("/old/path", "https://new.url") + ) + + assert config.uri_replacement[0] == "/old/path" + assert config.uri_replacement[1] == "https://new.url" + + +def test_web_loader_config_uri_replacement(): + config = WebLoaderConfig( + url="http://localhost:4000/test", + uri_replacement=("http://localhost:4000", "https://docs.github.com") + ) + + assert config.uri_replacement[0] == "http://localhost:4000" + assert config.uri_replacement[1] == "https://docs.github.com" + + +def test_doc_loader_config_validation(): + """Test that DOC_LOADER_CONFIGS has valid configuration for doc_loader.py logic""" + from scripts.loader_config import DOC_LOADER_CONFIGS + + config = DOC_LOADER_CONFIGS[0] + assert config.path == "/Users/nullchimp/Projects/customer-security-trust/FAQ" + assert config.file_extensions == ['.md'] + assert config.uri_replacement is not None + assert config.uri_replacement == ( + "/Users/nullchimp/Projects/customer-security-trust/FAQ", + "https://github.com/github/customer-security-trust/blob/main/FAQ" + ) + + # Test that the second config doesn't have URI replacement + config2 = DOC_LOADER_CONFIGS[1] + assert config2.path == "/Users/nullchimp/Projects/github-docs/content-copilot" + assert config2.file_extensions == ['.md'] + assert config2.uri_replacement is None + + +def test_doc_loader_uri_replacement_logic(): + config = DOC_LOADER_CONFIGS[0] + + # Simulate source URI that would come from DocumentLoader + mock_source_uri = "/Users/nullchimp/Projects/customer-security-trust/FAQ/security-faq.md" + + if config.uri_replacement: + old_pattern, new_pattern = config.uri_replacement + # This is the actual logic from doc_loader.py + new_uri = mock_source_uri.replace(old_pattern, new_pattern) + + expected_uri = "https://github.com/github/customer-security-trust/blob/main/FAQ/security-faq.md" + assert new_uri == expected_uri + + +def test_web_loader_uri_replacement_logic(): + config = WEB_LOADER_CONFIGS[0] + + # Simulate source URI that would come from WebLoader + mock_source_uri = "http://localhost:4000/en/enterprise-cloud@latest" + mock_source_name = "some-page.md" + + if config.uri_replacement: + old_pattern, new_pattern = config.uri_replacement + # This is the actual logic from url_loader.py + new_uri = f"{mock_source_uri.replace(old_pattern, new_pattern)}" + + expected_uri = "https://docs.github.com/en/enterprise-cloud@latest" + assert new_uri == expected_uri + + +def test_config_has_expected_structure(): + # Test that we have the expected number of configs + assert len(DOC_LOADER_CONFIGS) == 2 + assert len(WEB_LOADER_CONFIGS) == 1 + + # Test first doc loader config (with URI replacement) + first_doc_config = DOC_LOADER_CONFIGS[0] + assert first_doc_config.uri_replacement is not None + + # Test second doc loader config (without URI replacement) + second_doc_config = DOC_LOADER_CONFIGS[1] + assert second_doc_config.path == "/Users/nullchimp/Projects/github-docs/content-copilot" + assert second_doc_config.uri_replacement is None + + # Test web loader config + web_config = WEB_LOADER_CONFIGS[0] + assert web_config.url == "http://localhost:4000/en/enterprise-cloud@latest" + assert web_config.uri_replacement is not None