diff --git a/pyproject.toml b/pyproject.toml index 4001d983..c5c5d33b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,19 +8,14 @@ authors = [ ] requires-python = ">=3.10,<3.14" dependencies = [ - "pydantic>=2.6.1", "lancedb>=0.5.4", - "openai>=1.12.0", - "chromadb==0.5.23", "pytube>=15.0.0", "requests>=2.31.0", "docker>=7.1.0", "crewai>=0.193.2", - "click>=8.1.8", "lancedb>=0.5.4", "tiktoken>=0.8.0", "stagehand>=0.4.1", - "portalocker==2.7.0", "beautifulsoup4>=4.13.4", "pypdf>=5.9.0", "python-docx>=1.2.0", diff --git a/tests/tools/rag/rag_tool_test.py b/tests/tools/rag/rag_tool_test.py index 693cd120..d50d4949 100644 --- a/tests/tools/rag/rag_tool_test.py +++ b/tests/tools/rag/rag_tool_test.py @@ -1,54 +1,176 @@ +"""Tests for RAG tool with mocked embeddings and vector database.""" + from tempfile import TemporaryDirectory -from typing import cast +from typing import Any, cast from pathlib import Path +from unittest.mock import Mock, patch, MagicMock +import pytest from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter from crewai_tools.tools.rag.rag_tool import RagTool -def test_rag_tool_initialization(): +@patch('crewai_tools.adapters.crewai_rag_adapter.get_rag_client') +@patch('crewai_tools.adapters.crewai_rag_adapter.create_client') +def test_rag_tool_initialization( + mock_create_client: Mock, + mock_get_rag_client: Mock +) -> None: """Test that RagTool initializes with CrewAI adapter by default.""" + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_get_rag_client.return_value = mock_client + mock_create_client.return_value = mock_client + class MyTool(RagTool): pass tool = MyTool() assert tool.adapter is not None assert isinstance(tool.adapter, CrewAIRagAdapter) - + adapter = cast(CrewAIRagAdapter, tool.adapter) assert adapter.collection_name == "rag_tool_collection" assert adapter._client is not None -def test_rag_tool_add_and_query(): +@patch('crewai_tools.adapters.crewai_rag_adapter.get_rag_client') +@patch('crewai_tools.adapters.crewai_rag_adapter.create_client') +def test_rag_tool_add_and_query( + mock_create_client: Mock, + mock_get_rag_client: Mock +) -> None: """Test adding content and querying with RagTool.""" + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_client.add_documents = MagicMock(return_value=None) + mock_client.search = MagicMock(return_value=[ + {"content": "The sky is blue on a clear day.", "metadata": {}, "score": 0.9} + ]) + mock_get_rag_client.return_value = mock_client + mock_create_client.return_value = mock_client + class MyTool(RagTool): pass - + tool = MyTool() - + tool.add("The sky is blue on a clear day.") tool.add("Machine learning is a subset of artificial intelligence.") - + + # Verify documents were added + assert mock_client.add_documents.call_count == 2 + result = tool._run(query="What color is the sky?") assert "Relevant Content:" in result - + assert "The sky is blue" in result + + mock_client.search.return_value = [ + {"content": "Machine learning is a subset of artificial intelligence.", "metadata": {}, "score": 0.85} + ] + result = tool._run(query="Tell me about machine learning") assert "Relevant Content:" in result + assert "Machine learning" in result -def test_rag_tool_with_file(): +@patch('crewai_tools.adapters.crewai_rag_adapter.get_rag_client') +@patch('crewai_tools.adapters.crewai_rag_adapter.create_client') +def test_rag_tool_with_file( + mock_create_client: Mock, + mock_get_rag_client: Mock +) -> None: """Test RagTool with file content.""" + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_client.add_documents = MagicMock(return_value=None) + mock_client.search = MagicMock(return_value=[ + {"content": "Python is a programming language known for its simplicity.", "metadata": {"file_path": "test.txt"}, "score": 0.95} + ]) + mock_get_rag_client.return_value = mock_client + mock_create_client.return_value = mock_client + with TemporaryDirectory() as tmpdir: test_file = Path(tmpdir) / "test.txt" test_file.write_text("Python is a programming language known for its simplicity.") - + class MyTool(RagTool): pass - + tool = MyTool() tool.add(str(test_file)) - + + assert mock_client.add_documents.called + result = tool._run(query="What is Python?") assert "Relevant Content:" in result + assert "Python is a programming language" in result + + +@patch('crewai_tools.tools.rag.rag_tool.RagTool._create_embedding_function') +@patch('crewai_tools.adapters.crewai_rag_adapter.create_client') +def test_rag_tool_with_custom_embeddings( + mock_create_client: Mock, + mock_create_embedding: Mock +) -> None: + """Test RagTool with custom embeddings configuration to ensure no API calls.""" + mock_embedding_func = MagicMock() + mock_embedding_func.return_value = [[0.2] * 1536] + mock_create_embedding.return_value = mock_embedding_func + + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_client.add_documents = MagicMock(return_value=None) + mock_client.search = MagicMock(return_value=[ + {"content": "Test content", "metadata": {}, "score": 0.8} + ]) + mock_create_client.return_value = mock_client + + class MyTool(RagTool): + pass + + config = { + "vectordb": { + "provider": "chromadb", + "config": {} + }, + "embedding_model": { + "provider": "openai", + "config": { + "model": "text-embedding-3-small" + } + } + } + + tool = MyTool(config=config) + tool.add("Test content") + + result = tool._run(query="Test query") + assert "Relevant Content:" in result + assert "Test content" in result + + mock_create_embedding.assert_called() + + +@patch('crewai_tools.adapters.crewai_rag_adapter.get_rag_client') +@patch('crewai_tools.adapters.crewai_rag_adapter.create_client') +def test_rag_tool_no_results( + mock_create_client: Mock, + mock_get_rag_client: Mock +) -> None: + """Test RagTool when no relevant content is found.""" + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_client.search = MagicMock(return_value=[]) + mock_get_rag_client.return_value = mock_client + mock_create_client.return_value = mock_client + + class MyTool(RagTool): + pass + + tool = MyTool() + + result = tool._run(query="Non-existent content") + assert "Relevant Content:" in result + assert "No relevant content found" in result \ No newline at end of file diff --git a/uv.lock b/uv.lock index cf083faf..44529df9 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10, <3.14" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation == 'PyPy' and sys_platform == 'darwin'", @@ -1005,14 +1005,9 @@ version = "0.73.1" source = { editable = "." } dependencies = [ { name = "beautifulsoup4" }, - { name = "chromadb" }, - { name = "click" }, { name = "crewai" }, { name = "docker" }, { name = "lancedb" }, - { name = "openai" }, - { name = "portalocker" }, - { name = "pydantic" }, { name = "pypdf" }, { name = "python-docx" }, { name = "pytube" }, @@ -1154,8 +1149,6 @@ requires-dist = [ { name = "beautifulsoup4", marker = "extra == 'bedrock'", specifier = ">=4.13.4" }, { name = "bedrock-agentcore", marker = "extra == 'bedrock'", specifier = ">=0.1.0" }, { name = "browserbase", marker = "extra == 'browserbase'", specifier = ">=1.0.5" }, - { name = "chromadb", specifier = "==0.5.23" }, - { name = "click", specifier = ">=8.1.8" }, { name = "composio-core", marker = "extra == 'composio-core'", specifier = ">=0.6.11.post1" }, { name = "contextual-client", marker = "extra == 'contextual'", specifier = ">=0.1.0" }, { name = "couchbase", marker = "extra == 'couchbase'", specifier = ">=4.3.5" }, @@ -1176,13 +1169,10 @@ requires-dist = [ { name = "multion", marker = "extra == 'multion'", specifier = ">=1.1.0" }, { name = "nest-asyncio", marker = "extra == 'bedrock'", specifier = ">=1.6.0" }, { name = "nest-asyncio", marker = "extra == 'contextual'", specifier = ">=1.6.0" }, - { name = "openai", specifier = ">=1.12.0" }, { name = "oxylabs", marker = "extra == 'oxylabs'", specifier = "==2.0.0" }, { name = "patronus", marker = "extra == 'patronus'", specifier = ">=0.0.16" }, { name = "playwright", marker = "extra == 'bedrock'", specifier = ">=1.52.0" }, - { name = "portalocker", specifier = "==2.7.0" }, { name = "psycopg2-binary", marker = "extra == 'postgresql'", specifier = ">=2.9.10" }, - { name = "pydantic", specifier = ">=2.6.1" }, { name = "pygithub", marker = "extra == 'github'", specifier = "==1.59.1" }, { name = "pymongo", marker = "extra == 'mongodb'", specifier = ">=4.13" }, { name = "pymysql", marker = "extra == 'mysql'", specifier = ">=1.1.1" },