diff --git a/.env.template b/.env.template
index 3142c5b..ad6b09e 100644
--- a/.env.template
+++ b/.env.template
@@ -45,6 +45,11 @@ COSMOSDB_PARTITION_KEY="/id"
 SQL_DATABASE_URI="sqlite:///template_langgraph.db"
 # SQL_DATABASE_URI="postgresql://user:password@localhost:5432/db"
 
+# Azure AI Search Settings
+AI_SEARCH_ENDPOINT="https://xxx.search.windows.net/"
+AI_SEARCH_KEY="xxx"
+AI_SEARCH_INDEX_NAME="kabuto"
+
 # ---------
 # Utilities
 # ---------
diff --git a/docs/references.md b/docs/references.md
index 576c78c..900a287 100644
--- a/docs/references.md
+++ b/docs/references.md
@@ -21,6 +21,7 @@
 - [CSVLoader](https://python.langchain.com/docs/how_to/document_loader_csv/)
 - [Qdrant](https://github.com/qdrant/qdrant)
 - [Azure Cosmos DB No SQL](https://python.langchain.com/docs/integrations/vectorstores/azure_cosmos_db_no_sql/)
+- [Azure AI Search](https://python.langchain.com/docs/integrations/vectorstores/azuresearch/)
 
 ### Services
 
diff --git a/pyproject.toml b/pyproject.toml
index db63476..e2aedd5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,6 +6,8 @@ readme = "README.md"
 requires-python = ">=3.10"
 dependencies = [
     "azure-cosmos>=4.9.0",
+    "azure-identity>=1.23.1",
+    "azure-search-documents>=11.5.3",
     "elasticsearch>=9.1.0",
     "fastapi[standard]>=0.116.1",
     "httpx>=0.28.1",
diff --git a/scripts/ai_search_operator.py b/scripts/ai_search_operator.py
new file mode 100644
index 0000000..c52da87
--- /dev/null
+++ b/scripts/ai_search_operator.py
@@ -0,0 +1,102 @@
+import logging
+
+import typer
+from dotenv import load_dotenv
+
+from template_langgraph.loggers import get_logger
+from template_langgraph.tools.ai_search_tool import AiSearchClientWrapper
+from template_langgraph.utilities.csv_loaders import CsvLoaderWrapper
+from template_langgraph.utilities.pdf_loaders import PdfLoaderWrapper
+
+# Initialize the Typer application
+app = typer.Typer(
+    add_completion=False,
+    help="AI Search operator CLI",
+)
+
+# Set up logging
+logger = get_logger(__name__)
+
+
+@app.command()
+def add_documents(
+    verbose: bool = typer.Option(
+        False,
+        "--verbose",
+        "-v",
+        help="Enable verbose output",
+    ),
+):
+    # Set up logging
+    if verbose:
+        logger.setLevel(logging.DEBUG)
+
+    # Load documents from PDF files
+    pdf_documents = PdfLoaderWrapper().load_pdf_docs()
+    logger.info(f"Loaded {len(pdf_documents)} documents from PDF.")
+
+    # Load documents from CSV files
+    csv_documents = CsvLoaderWrapper().load_csv_docs()
+    logger.info(f"Loaded {len(csv_documents)} documents from CSV.")
+
+    # Combine all documents
+    documents = pdf_documents + csv_documents
+    logger.info(f"Total documents to add: {len(documents)}")
+
+    # Add documents to AI Search
+    ai_search_client = AiSearchClientWrapper()
+    ids = ai_search_client.add_documents(
+        documents=documents,
+    )
+    logger.info(f"Added {len(ids)} documents to AI Search.")
+    for document_id in ids:
+        logger.debug(f"Added document ID: {document_id}")
+
+
+@app.command()
+def similarity_search(
+    query: str = typer.Option(
+        "禅モード",
+        "--query",
+        "-q",
+        help="Query to search in the AI Search index",
+    ),
+    k: int = typer.Option(
+        5,
+        "--k",
+        "-k",
+        help="Number of results to return from the similarity search",
+    ),
+    verbose: bool = typer.Option(
+        False,
+        "--verbose",
+        "-v",
+        help="Enable verbose output",
+    ),
+):
+    # Set up logging
+    if verbose:
+        logger.setLevel(logging.DEBUG)
+
+    logger.info(f"Searching AI Search with query: {query}")
+
+    # Perform similarity search
+    ai_search_client = AiSearchClientWrapper()
+    documents = ai_search_client.similarity_search(
+        query=query,
+        k=k,  # Number of results to return
+    )
+    logger.info(f"Found {len(documents)} results for query: {query}")
+
+    # Log the results
+    for i, document in enumerate(documents, start=1):
+        logger.debug("-" * 40)
+        logger.debug(f"#{i}: {document.model_dump_json(indent=2)}")
+
+
+if __name__ == "__main__":
+    load_dotenv(
+        override=True,
+        verbose=True,
+    )
+    app()
diff --git a/template_langgraph/tools/ai_search_tool.py b/template_langgraph/tools/ai_search_tool.py
new file mode 100644
index 0000000..2e2d923
--- /dev/null
+++ b/template_langgraph/tools/ai_search_tool.py
@@ -0,0 +1,105 @@
+from functools import lru_cache
+
+from langchain_community.vectorstores.azuresearch import AzureSearch
+from langchain_core.documents import Document
+from langchain_core.tools import tool
+from pydantic import BaseModel, Field
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+from template_langgraph.llms.azure_openais import AzureOpenAiWrapper
+
+
+class Settings(BaseSettings):
+    ai_search_key: str = ""
+    ai_search_endpoint: str = ""
+    ai_search_index_name: str = ""
+
+    model_config = SettingsConfigDict(
+        env_file=".env",
+        env_ignore_empty=True,
+        extra="ignore",
+    )
+
+
+@lru_cache
+def get_ai_search_settings() -> Settings:
+    """Get AI Search settings."""
+    return Settings()
+
+
+class AiSearchClientWrapper:
+    def __init__(
+        self,
+        settings: Settings | None = None,
+    ):
+        if settings is None:
+            settings = get_ai_search_settings()
+        self.vector_store: AzureSearch = AzureSearch(
+            azure_search_endpoint=settings.ai_search_endpoint,
+            azure_search_key=settings.ai_search_key,
+            index_name=settings.ai_search_index_name,
+            embedding_function=AzureOpenAiWrapper().embedding_model.embed_query,
+        )
+
+    def add_documents(
+        self,
+        documents: list[Document],
+    ) -> list[str]:
+        """Add documents to the AI Search index."""
+        return self.vector_store.add_documents(
+            documents=documents,
+        )
+
+    def similarity_search(
+        self,
+        query: str,
+        k: int = 5,
+    ) -> list[Document]:
+        """Perform a similarity search against the AI Search index."""
+        return self.vector_store.similarity_search(
+            query=query,
+            k=k,  # Number of results to return
+        )
+
+
+class AiSearchInput(BaseModel):
+    query: str = Field(
+        default="禅モード",
+        description="Query to search in the AI Search index",
+    )
+    k: int = Field(
+        default=5,
+        description="Number of results to return from the similarity search",
+    )
+
+
+class AiSearchOutput(BaseModel):
+    content: str = Field(description="Content of the document")
+    id: str = Field(description="ID of the document")
+
+
+@tool(args_schema=AiSearchInput)
+def search_ai_search(query: str, k: int = 5) -> list[AiSearchOutput]:
+    """Search for similar documents in the AI Search index.
+
+    Args:
+        query: The search query string
+        k: Number of results to return (default: 5)
+
+    Returns:
+        list[AiSearchOutput]: The matching documents as content/id pairs
+    """
+    wrapper = AiSearchClientWrapper()
+    documents = wrapper.similarity_search(
+        query=query,
+        k=k,
+    )
+    outputs = []
+    for document in documents:
+        outputs.append(
+            AiSearchOutput(
+                content=document.page_content,
+                id=document.id,
+            )
+        )
+    return outputs
diff --git a/template_langgraph/tools/common.py b/template_langgraph/tools/common.py
index a7c3621..eed5c03 100644
--- a/template_langgraph/tools/common.py
+++ b/template_langgraph/tools/common.py
@@ -1,5 +1,6 @@
 from template_langgraph.llms.azure_openais import AzureOpenAiWrapper
 from template_langgraph.loggers import get_logger
+from template_langgraph.tools.ai_search_tool import search_ai_search
 from template_langgraph.tools.cosmosdb_tool import search_cosmosdb
 from template_langgraph.tools.dify_tool import run_dify_workflow
 from template_langgraph.tools.elasticsearch_tool import search_elasticsearch
@@ -18,6 +19,7 @@ def get_default_tools():
         logger.error(f"Error occurred while getting SQL database tools: {e}")
         sql_database_tools = []
     return [
+        search_ai_search,
         search_cosmosdb,
         run_dify_workflow,
         search_qdrant,
diff --git a/uv.lock b/uv.lock
index 8f75040..580e371 100644
--- a/uv.lock
+++ b/uv.lock
@@ -316,6 +316,15 @@ opentelemetry = [
     { name = "azure-core-tracing-opentelemetry" },
 ]
 
+[[package]]
+name = "azure-common"
+version = "1.1.28"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/3e/71/f6f71a276e2e69264a97ad39ef850dca0a04fce67b12570730cb38d0ccac/azure-common-1.1.28.zip", hash = "sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3", size = 20914, upload-time = "2022-02-03T19:39:44.373Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/62/55/7f118b9c1b23ec15ca05d15a578d8207aa1706bc6f7c87218efffbbf875d/azure_common-1.1.28-py2.py3-none-any.whl", hash = "sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad", size = 14462, upload-time = "2022-02-03T19:39:42.417Z" },
+]
+
 [[package]]
 name = "azure-core"
 version = "1.35.0"
@@ -372,6 +381,21 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/99/b3/e2d7ab810eb68575a5c7569b03c0228b8f4ce927ffa6211471b526f270c9/azure_identity-1.23.1-py3-none-any.whl", hash = "sha256:7eed28baa0097a47e3fb53bd35a63b769e6b085bb3cb616dfce2b67f28a004a1", size = 186810, upload-time = "2025-07-15T19:16:40.184Z" },
 ]
 
+[[package]]
+name = "azure-search-documents"
+version = "11.5.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "azure-common" },
+    { name = "azure-core" },
+    { name = "isodate" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/fd/11/9ecde2bd9e6c00cc0e3f312ab096a33d333f8ba40c847f01f94d524895fe/azure_search_documents-11.5.3.tar.gz", hash = "sha256:6931149ec0db90485d78648407f18ea4271420473c7cb646bf87790374439989", size = 300353, upload-time = "2025-06-25T16:48:58.924Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/4b/f5/0f6b52567cbb33f1efba13060514ed7088a86de84d74b77cda17d278bcd9/azure_search_documents-11.5.3-py3-none-any.whl", hash = "sha256:110617751c6c8bd50b1f0af2b00a478bd4fbaf4e2f0387e3454c26ec3eb433d6", size = 298772, upload-time = "2025-06-25T16:49:00.764Z" },
+]
+
 [[package]]
 name = "babel"
 version = "2.17.0"
@@ -4508,6 +4532,8 @@ version = "0.0.1"
 source = { editable = "." }
 dependencies = [
     { name = "azure-cosmos" },
+    { name = "azure-identity" },
+    { name = "azure-search-documents" },
     { name = "elasticsearch" },
     { name = "fastapi", extra = ["standard"] },
     { name = "httpx" },
@@ -4547,6 +4573,8 @@ docs = [
 [package.metadata]
 requires-dist = [
     { name = "azure-cosmos", specifier = ">=4.9.0" },
+    { name = "azure-identity", specifier = ">=1.23.1" },
+    { name = "azure-search-documents", specifier = ">=11.5.3" },
     { name = "elasticsearch", specifier = ">=9.1.0" },
     { name = "fastapi", extras = ["standard"], specifier = ">=0.116.1" },
     { name = "httpx", specifier = ">=0.28.1" },
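
For reviewers, a minimal usage sketch of the new wrapper outside the Typer CLI (not part of this diff), assuming the package is installed and `.env` supplies the `AI_SEARCH_*` values shown in `.env.template` above:

```python
from dotenv import load_dotenv

from template_langgraph.tools.ai_search_tool import AiSearchClientWrapper

# Load AI_SEARCH_ENDPOINT / AI_SEARCH_KEY / AI_SEARCH_INDEX_NAME from .env
load_dotenv(override=True)

# Query the configured index; k caps the number of results returned
wrapper = AiSearchClientWrapper()
for document in wrapper.similarity_search(query="禅モード", k=3):
    print(document.page_content)
```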